diff --git a/README.md b/README.md index 9e749c5..055bf9c 100644 --- a/README.md +++ b/README.md @@ -9,14 +9,14 @@ A basic routine is as simple as: ```py import jax.numpy as jnp import exciting_environments as excenvs +from exciting_environments import EnvironmentType from exciting_environments.utils import MinMaxNormalization -env = excenvs.make( - "Pendulum-v0", +env = EnvironmentType.PENDULUM.make( batch_size=5, action_normalizations={"torque": MinMaxNormalization(min=-15,max=15)}, tau=2e-2 -) +) obs, state = env.vmap_reset() actions = jnp.linspace(start=-1, stop=1, num=1000)[None, :, None] @@ -45,11 +45,11 @@ alternatively, simulate full trajectories: ```py import jax.numpy as jnp import exciting_environments as excenvs +from exciting_environments import EnvironmentType from exciting_environments.utils import MinMaxNormalization import diffrax -env = excenvs.make( - "Pendulum-v0", +env = EnvironmentType.PENDULUM.make( solver=diffrax.Tsit5(), batch_size=5, action_normalizations={"torque": MinMaxNormalization(min=-15,max=15)}, @@ -76,4 +76,4 @@ which produces $5$ identical trajectories in parallel as well: ![](https://github.com/ExcitingSystems/exciting-environments/blob/main/fig/excenvs_pendulum_simulation_example_advanced.png?raw=true) Note that in this case the Tsit5 ODE solver instead of the default explicit Euler is used. -All solvers used here are from the diffrax library (https://docs.kidger.site/diffrax/). \ No newline at end of file +All solvers used here are from the diffrax library (https://docs.kidger.site/diffrax/). diff --git a/exciting_environments/__init__.py b/exciting_environments/__init__.py index e377e95..a557463 100644 --- a/exciting_environments/__init__.py +++ b/exciting_environments/__init__.py @@ -5,6 +5,6 @@ from .pendulum import Pendulum from .fluid_tank import FluidTank from .acrobot import Acrobot -from .registration import make +from .registration import EnvironmentType from .gym_wrapper import GymWrapper from .mujoco_wrapper import MujucoWrapper diff --git a/exciting_environments/gym_wrapper.py b/exciting_environments/gym_wrapper.py index b315cac..98e3e32 100644 --- a/exciting_environments/gym_wrapper.py +++ b/exciting_environments/gym_wrapper.py @@ -6,8 +6,9 @@ from functools import partial import chex from abc import ABC -from exciting_environments import spaces -from exciting_environments.registration import make +from exciting_environments import spaces, EnvironmentType + +# from exciting_environments.registration import make class GymWrapper(ABC): @@ -58,9 +59,9 @@ def __init__( self.generate_terminated = self.env.generate_terminated @classmethod - def from_name(cls, env_id: str, **env_kwargs): - """Creates GymWrapper with environment based on passed env_id.""" - env = make(env_id, **env_kwargs) + def from_env(cls, env_type: EnvironmentType, **env_kwargs): + """Creates GymWrapper with environment based on passed EnvironmentType.""" + env = env_type.make(**env_kwargs) return cls(env) def step(self, action): diff --git a/exciting_environments/pmsm/__init__.py b/exciting_environments/pmsm/__init__.py index ef8bc91..52addba 100644 --- a/exciting_environments/pmsm/__init__.py +++ b/exciting_environments/pmsm/__init__.py @@ -1,2 +1,2 @@ -from .motor_parameters import default_params +from .motor_parameters import MotorVariant from .pmsm_env import PMSM # , PMSM_Physical diff --git a/exciting_environments/pmsm/motor_parameters.py b/exciting_environments/pmsm/motor_parameters.py index 4319974..1057e98 100644 --- a/exciting_environments/pmsm/motor_parameters.py +++ b/exciting_environments/pmsm/motor_parameters.py @@ -10,6 +10,8 @@ from exciting_environments.utils import MinMaxNormalization from copy import deepcopy +from enum import Enum + @jdc.pytree_dataclass class PhysicalNormalizations: @@ -147,21 +149,15 @@ def default_soft_constraints(self, state, action_norm, env_properties): ) -def default_params(name): - """ - Returns default parameters for specified motor configurations. - - Args: - name (str): Name of the motor ("BRUSA" or "SEW"). - - Returns: - MotorConfig: Configuration containing physical constraints, action constraints, static parameters, and LUT data. - """ - if name is None: - return deepcopy(DEFAULT) - elif name == "BRUSA": - return deepcopy(BRUSA) - elif name == "SEW": - return deepcopy(SEW) - else: - raise ValueError(f"Motor name {name} is not known.") +class MotorVariant(Enum): + DEFAULT = "DEFAULT" + BRUSA = "BRUSA" + SEW = "SEW" + + def get_params(self): + if self is MotorVariant.BRUSA: + return deepcopy(BRUSA) + elif self is MotorVariant.SEW: + return deepcopy(SEW) + else: + return deepcopy(DEFAULT) diff --git a/exciting_environments/pmsm/pmsm_env.py b/exciting_environments/pmsm/pmsm_env.py index 18abd53..212a602 100644 --- a/exciting_environments/pmsm/pmsm_env.py +++ b/exciting_environments/pmsm/pmsm_env.py @@ -14,7 +14,7 @@ from copy import deepcopy from exciting_environments import CoreEnvironment -from exciting_environments.pmsm import default_params +from exciting_environments.pmsm import MotorVariant # only for alpha/beta -> abc @@ -117,7 +117,7 @@ def __init__( self, batch_size: int = 8, saturated=False, - LUT_motor_name: str = None, + motor_variant: MotorVariant = MotorVariant.DEFAULT, physical_normalizations: dict = None, action_normalizations: dict = None, soft_constraints: Callable = None, @@ -130,7 +130,7 @@ def __init__( Args: batch_size (int): Number of parallel environment simulations. Default: 8 saturated (bool): Permanent magnet flux linkages and inductances are taken from LUT_motor_name specific LUTs. Default: False - LUT_motor_name (str): Sets physical_normalizations, action_normalizations, soft_constraints and static_params to default values for the passed motor name and stores associated LUTs for the possible saturated case. Needed if saturated==True. + motor_variant (MotorVariant): Sets physical_normalizations, action_normalizations, soft_constraints and static_params to default values for the passed motor variant and stores associated LUTs for the possible saturated case. Needed if saturated==True. physical_normalizations (dict): min-max normalization values of the physical state of the environment. u_d_buffer (MinMaxNormalization): Direct share of the delayed action due to system deadtime. Default: min=-2 * 400 / 3, max=2 * 400 / 3 u_q_buffer (MinMaxNormalization): Quadrature share of the delayed action due to system deadtime. Default: min=-2 * 400 / 3, max=2 * 400 / 3 @@ -161,8 +161,8 @@ def __init__( self.tau = tau self._solver = solver - if LUT_motor_name is not None: - motor_params = deepcopy(default_params(LUT_motor_name)) + if motor_variant != MotorVariant.DEFAULT: + motor_params = motor_variant.get_params() default_physical_normalizations = motor_params.physical_normalizations.__dict__ default_action_normalizations = motor_params.action_normalizations.__dict__ default_static_params = motor_params.static_params.__dict__ @@ -187,7 +187,10 @@ def __init__( else: if saturated: - raise Exception("LUT_motor_name is needed to load LUTs.") + raise ValueError( + f"MotorVariant '{motor_variant.value}' is not allowed for saturated LUTs. " + "Use a specific motor variant. DEFAULT is only valid for saturated=False." + ) saturated_quants = [ "L_dd", @@ -198,7 +201,7 @@ def __init__( "Psi_q", ] - motor_params = deepcopy(default_params(LUT_motor_name)) + motor_params = motor_variant.get_params() default_physical_normalizations = motor_params.physical_normalizations.__dict__ default_action_normalizations = motor_params.action_normalizations.__dict__ default_static_params = motor_params.static_params.__dict__ diff --git a/exciting_environments/registration.py b/exciting_environments/registration.py index 84add20..797139e 100644 --- a/exciting_environments/registration.py +++ b/exciting_environments/registration.py @@ -6,29 +6,27 @@ PMSM, Acrobot, ) - - -def make(env_id: str, **env_kwargs): - if env_id == "CartPole-v0": - env = CartPole(**env_kwargs) - - elif env_id == "MassSpringDamper-v0": - env = MassSpringDamper(**env_kwargs) - - elif env_id == "Pendulum-v0": - env = Pendulum(**env_kwargs) - - elif env_id == "FluidTank-v0": - env = FluidTank(**env_kwargs) - - elif env_id == "PMSM-v0": - env = PMSM(**env_kwargs) - - elif env_id == "Acrobot-v0": - env = Acrobot(**env_kwargs) - - else: - print(f"No existing environments got env_id ={env_id}") - env = None - - return env +from enum import Enum + + +class EnvironmentType(Enum): + CART_POLE = "CartPole-v0" + MASS_SPRING_DAMPER = "MassSpringDamper-v0" + PENDULUM = "Pendulum-v0" + FLUID_TANK = "FluidTank-v0" + PMSM = "PMSM-v0" + ACROBOT = "Acrobot-v0" + + def make(self, **env_kwargs): + env_map = { + EnvironmentType.CART_POLE: CartPole, + EnvironmentType.MASS_SPRING_DAMPER: MassSpringDamper, + EnvironmentType.PENDULUM: Pendulum, + EnvironmentType.FLUID_TANK: FluidTank, + EnvironmentType.PMSM: PMSM, + EnvironmentType.ACROBOT: Acrobot, + } + cls = env_map.get(self) + if cls is None: + raise ValueError(f"Unknown environment: {self}") + return cls(**env_kwargs) diff --git a/tests/envs/acrobot/test_acrobot.py b/tests/envs/acrobot/test_acrobot.py index 66a1f5a..31bc142 100644 --- a/tests/envs/acrobot/test_acrobot.py +++ b/tests/envs/acrobot/test_acrobot.py @@ -4,11 +4,13 @@ import jax.numpy as jnp import numpy as np import diffrax +from exciting_environments import EnvironmentType from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json from pathlib import Path import pickle import os + jax.config.update("jax_enable_x64", True) @@ -33,7 +35,7 @@ def test_default_initialization(): "omega_1": MinMaxNormalization(min=-10, max=10), "omega_2": MinMaxNormalization(min=-10, max=10), } - env = excenvs.make("Acrobot-v0", batch_size=batch_size) + env = EnvironmentType.ACROBOT.make(batch_size=batch_size) for key, value in params.items(): env_value = getattr(env.env_properties.static_params, key) if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray): @@ -101,8 +103,7 @@ def test_custom_initialization(): "omega_1": MinMaxNormalization(min=-55, max=10), "omega_2": MinMaxNormalization(min=-10, max=30), } - env = excenvs.make( - "Acrobot-v0", + env = EnvironmentType.ACROBOT.make( batch_size=batch_size, static_params=params, physical_normalizations=physical_normalizations, @@ -159,8 +160,7 @@ def test_step_results(): loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = ( load_sim_properties_from_json(file_path) ) - env = excenvs.make( - "Acrobot-v0", + env = EnvironmentType.ACROBOT.make( tau=loaded_tau, solver=diffrax.Euler(), static_params=loaded_params, diff --git a/tests/envs/cartpole/test_cartpole.py b/tests/envs/cartpole/test_cartpole.py index b2cc95d..e7834ab 100644 --- a/tests/envs/cartpole/test_cartpole.py +++ b/tests/envs/cartpole/test_cartpole.py @@ -4,13 +4,16 @@ import jax.numpy as jnp import numpy as np import diffrax -from exciting_environments.utils import MinMaxNormalization,load_sim_properties_from_json +from exciting_environments import EnvironmentType +from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json from pathlib import Path import pickle import os + jax.config.update("jax_enable_x64", True) + def test_default_initialization(): """Ensure default static parameters and normalizations are not changed by accident.""" batch_size = 4 @@ -29,7 +32,7 @@ def test_default_initialization(): "theta": MinMaxNormalization(min=-jnp.pi, max=jnp.pi), "omega": MinMaxNormalization(min=-8, max=8), } - env = excenvs.make("CartPole-v0", batch_size=batch_size) + env = EnvironmentType.CART_POLE.make(batch_size=batch_size) for key, value in params.items(): env_value = getattr(env.env_properties.static_params, key) if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray): @@ -94,8 +97,7 @@ def test_static_parameters_initialization(): "m_c": jnp.repeat(1, batch_size), "g": 35.81, } - env = excenvs.make( - "CartPole-v0", + env = EnvironmentType.CART_POLE.make( batch_size=batch_size, static_params=params, physical_normalizations=physical_normalizations, @@ -150,9 +152,10 @@ def test_static_parameters_initialization(): def test_step_results(): data_dir = os.path.join(Path(__file__).parent, "data") file_path = os.path.join(data_dir, "sim_properties.json") - loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path) - env = excenvs.make( - "CartPole-v0", + loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = ( + load_sim_properties_from_json(file_path) + ) + env = EnvironmentType.CART_POLE.make( tau=loaded_tau, solver=diffrax.Euler(), static_params=loaded_params, @@ -170,4 +173,4 @@ def test_step_results(): obs, state = env.step(state, action, env.env_properties) generated_observations.append(obs) generated_observations = jnp.array(generated_observations) - assert jnp.allclose(generated_observations, stored_observations, 1e-16), "Step function generates different data" \ No newline at end of file + assert jnp.allclose(generated_observations, stored_observations, 1e-16), "Step function generates different data" diff --git a/tests/envs/fluid_tank/test_fluid_tank.py b/tests/envs/fluid_tank/test_fluid_tank.py index 03b6718..2370b39 100644 --- a/tests/envs/fluid_tank/test_fluid_tank.py +++ b/tests/envs/fluid_tank/test_fluid_tank.py @@ -4,7 +4,8 @@ import jax.numpy as jnp import numpy as np import diffrax -from exciting_environments.utils import MinMaxNormalization,load_sim_properties_from_json +from exciting_environments import EnvironmentType +from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json from pathlib import Path import pickle import os @@ -18,7 +19,7 @@ def test_default_initialization(): params = {"base_area": jnp.pi, "orifice_area": jnp.pi * 0.1**2, "c_d": 0.6, "g": 9.81} action_normalizations = {"inflow": MinMaxNormalization(min=0, max=0.2)} physical_normalizations = {"height": MinMaxNormalization(min=0, max=3)} - env = excenvs.make("FluidTank-v0", batch_size=batch_size) + env = EnvironmentType.FLUID_TANK.make(batch_size=batch_size) for key, value in params.items(): env_value = getattr(env.env_properties.static_params, key) if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray): @@ -71,8 +72,7 @@ def test_custom_initialization(): params = {"base_area": jnp.repeat(jnp.pi, batch_size), "orifice_area": jnp.pi * 0.1**2, "c_d": 0.6, "g": 9.81} action_normalizations = {"inflow": MinMaxNormalization(min=jnp.repeat(0.02, batch_size), max=0.3)} physical_normalizations = {"height": MinMaxNormalization(min=1, max=5)} - env = excenvs.make( - "FluidTank-v0", + env = EnvironmentType.FLUID_TANK.make( batch_size=batch_size, static_params=params, physical_normalizations=physical_normalizations, @@ -127,9 +127,10 @@ def test_custom_initialization(): def test_step_results(): data_dir = os.path.join(Path(__file__).parent, "data") file_path = os.path.join(data_dir, "sim_properties.json") - loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path) - env = excenvs.make( - "FluidTank-v0", + loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = ( + load_sim_properties_from_json(file_path) + ) + env = EnvironmentType.FLUID_TANK.make( tau=loaded_tau, solver=diffrax.Euler(), static_params=loaded_params, diff --git a/tests/envs/mass_spring_damper/test_mass_spring_damper.py b/tests/envs/mass_spring_damper/test_mass_spring_damper.py index 3ce2020..956772f 100644 --- a/tests/envs/mass_spring_damper/test_mass_spring_damper.py +++ b/tests/envs/mass_spring_damper/test_mass_spring_damper.py @@ -4,6 +4,7 @@ import jax.numpy as jnp import numpy as np import diffrax +from exciting_environments import EnvironmentType from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json from pathlib import Path import pickle @@ -11,6 +12,7 @@ jax.config.update("jax_enable_x64", True) + def test_default_initialization(): """Ensure default static parameters and normalizations are not changed by accident.""" batch_size = 4 @@ -20,7 +22,7 @@ def test_default_initialization(): "deflection": MinMaxNormalization(min=-10, max=10), "velocity": MinMaxNormalization(min=-10, max=10), } - env = excenvs.make("MassSpringDamper-v0", batch_size=batch_size) + env = EnvironmentType.MASS_SPRING_DAMPER.make(batch_size=batch_size) for key, value in params.items(): env_value = getattr(env.env_properties.static_params, key) if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray): @@ -76,8 +78,7 @@ def test_custom_initialization(): } action_normalizations = {"force": MinMaxNormalization(min=-10, max=20)} params = {"k": jnp.repeat(10, batch_size), "m": 5, "d": 2} - env = excenvs.make( - "MassSpringDamper-v0", + env = EnvironmentType.MASS_SPRING_DAMPER.make( batch_size=batch_size, static_params=params, physical_normalizations=physical_normalizations, @@ -132,9 +133,10 @@ def test_custom_initialization(): def test_step_results(): data_dir = os.path.join(Path(__file__).parent, "data") file_path = os.path.join(data_dir, "sim_properties.json") - loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path) - env = excenvs.make( - "MassSpringDamper-v0", + loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = ( + load_sim_properties_from_json(file_path) + ) + env = EnvironmentType.MASS_SPRING_DAMPER.make( tau=loaded_tau, solver=diffrax.Euler(), static_params=loaded_params, diff --git a/tests/envs/pendulum/test_pendulum.py b/tests/envs/pendulum/test_pendulum.py index 50cefd0..865f757 100644 --- a/tests/envs/pendulum/test_pendulum.py +++ b/tests/envs/pendulum/test_pendulum.py @@ -4,6 +4,7 @@ import jax.numpy as jnp import numpy as np import diffrax +from exciting_environments import EnvironmentType from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json from pathlib import Path import pickle @@ -21,7 +22,7 @@ def test_default_initialization(): "theta": MinMaxNormalization(min=-jnp.pi, max=jnp.pi), "omega": MinMaxNormalization(min=-10, max=10), } - env = excenvs.make("Pendulum-v0", batch_size=batch_size) + env = EnvironmentType.PENDULUM.make(batch_size=batch_size) for key, value in params.items(): env_value = getattr(env.env_properties.static_params, key) if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray): @@ -77,8 +78,7 @@ def test_custom_initialization(): } action_normalizations = {"torque": MinMaxNormalization(min=-10, max=10)} params = {"l": jnp.repeat(1, batch_size), "g": 9.81, "m": 1} - env = excenvs.make( - "Pendulum-v0", + env = EnvironmentType.PENDULUM.make( batch_size=batch_size, static_params=params, physical_normalizations=physical_normalizations, @@ -132,9 +132,10 @@ def test_custom_initialization(): def test_step_results(): data_dir = os.path.join(Path(__file__).parent, "data") file_path = os.path.join(data_dir, "sim_properties.json") - loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path) - env = excenvs.make( - "Pendulum-v0", + loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = ( + load_sim_properties_from_json(file_path) + ) + env = EnvironmentType.PENDULUM.make( tau=loaded_tau, solver=diffrax.Euler(), static_params=loaded_params, diff --git a/tests/envs/pmsm/motor_parameters.py b/tests/envs/pmsm/motor_parameters.py index 69559ae..7aca978 100644 --- a/tests/envs/pmsm/motor_parameters.py +++ b/tests/envs/pmsm/motor_parameters.py @@ -10,6 +10,9 @@ from exciting_environments.utils import MinMaxNormalization from copy import deepcopy +from enum import Enum + + @jdc.pytree_dataclass class PhysicalNormalizations: u_d_buffer: float @@ -111,7 +114,7 @@ def default_soft_constraints(self, state, action_norm, env_properties): deadtime=1, ), default_soft_constraints=default_soft_constraints, - pmsm_lut=None, + pmsm_lut=None, ) DEFAULT = MotorParams( @@ -142,21 +145,15 @@ def default_soft_constraints(self, state, action_norm, env_properties): ) -def default_params(name): - """ - Returns default parameters for specified motor configurations. - - Args: - name (str): Name of the motor ("BRUSA" or "SEW"). - - Returns: - MotorConfig: Configuration containing physical constraints, action constraints, static parameters, and LUT data. - """ - if name is None: - return deepcopy(DEFAULT) - elif name == "BRUSA": - return deepcopy(BRUSA) - elif name == "SEW": - return deepcopy(SEW) - else: - raise ValueError(f"Motor name {name} is not known.") +class MotorVariant(Enum): + DEFAULT = "DEFAULT" + BRUSA = "BRUSA" + SEW = "SEW" + + def get_params(self): + if self is MotorVariant.BRUSA: + return deepcopy(BRUSA) + elif self is MotorVariant.SEW: + return deepcopy(SEW) + else: + return deepcopy(DEFAULT) diff --git a/tests/envs/pmsm/test_pmsm.py b/tests/envs/pmsm/test_pmsm.py index 53659f2..d3cb8c1 100644 --- a/tests/envs/pmsm/test_pmsm.py +++ b/tests/envs/pmsm/test_pmsm.py @@ -4,26 +4,27 @@ import jax.numpy as jnp import numpy as np import diffrax +from exciting_environments import EnvironmentType from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json from pathlib import Path -from motor_parameters import default_params +from motor_parameters import MotorVariant import pickle import os jax.config.update("jax_enable_x64", True) -motor_names = ["BRUSA", "SEW", None] +motor_variants = list(MotorVariant) -@pytest.mark.parametrize("motor_name", motor_names) -def test_default_initialization(motor_name): +@pytest.mark.parametrize("motor_variant", motor_variants) +def test_default_initialization(motor_variant): """Ensure default static parameters and normalizations are not changed by accident.""" - motor_params = default_params(motor_name) + motor_params = motor_variant.get_params() physical_normalizations = motor_params.physical_normalizations.__dict__ action_normalizations = motor_params.action_normalizations.__dict__ params = motor_params.static_params.__dict__ - env = excenvs.make("PMSM-v0", LUT_motor_name=motor_name) + env = EnvironmentType.PMSM.make(motor_variant=motor_variant) for key, value in params.items(): env_value = getattr(env.env_properties.static_params, key) if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray): @@ -92,11 +93,10 @@ def test_custom_initialization(): "l_d": 0.37e-3, "l_q": 1.2e-3, "psi_p": 65.6e-3, - "u_dc":400, + "u_dc": 400, "deadtime": 1, } - env = excenvs.make( - "PMSM-v0", + env = EnvironmentType.PMSM.make( batch_size=batch_size, static_params=params, physical_normalizations=physical_normalizations, @@ -150,9 +150,10 @@ def test_custom_initialization(): def test_step_results(): data_dir = os.path.join(Path(__file__).parent, "data") file_path = os.path.join(data_dir, "sim_properties.json") - loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path) - env = excenvs.make( - "PMSM-v0", + loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = ( + load_sim_properties_from_json(file_path) + ) + env = EnvironmentType.PMSM.make( tau=loaded_tau, solver=diffrax.Euler(), static_params=loaded_params, @@ -170,4 +171,4 @@ def test_step_results(): obs, state = env.step(state, action, env.env_properties) generated_observations.append(obs) generated_observations = jnp.array(generated_observations) - assert jnp.allclose(generated_observations, stored_observations, 1e-8 ), "Step function generates different data" + assert jnp.allclose(generated_observations, stored_observations, 1e-8), "Step function generates different data" diff --git a/tests/envs/test_core_functions.py b/tests/envs/test_core_functions.py index 3394094..f622374 100644 --- a/tests/envs/test_core_functions.py +++ b/tests/envs/test_core_functions.py @@ -6,26 +6,26 @@ import diffrax from jax.tree_util import tree_flatten, tree_unflatten, tree_structure - +from exciting_environments import EnvironmentType jax.config.update("jax_platform_name", "cpu") jax.config.update("jax_enable_x64", True) -env_ids = ["Pendulum-v0", "MassSpringDamper-v0", "CartPole-v0", "FluidTank-v0", "PMSM-v0", "Acrobot-v0"] -fully_observable_env_ids = env_ids +envs_to_test = list(EnvironmentType) +fully_observable_envs = envs_to_test -@pytest.mark.parametrize("env_id", env_ids) +@pytest.mark.parametrize("env_type", envs_to_test) @pytest.mark.parametrize("tau", [1e-4, 1e-5]) -def test_tau(env_id, tau): - env = excenvs.make(env_id, tau=tau) +def test_tau(env_type, tau): + env = env_type.make(tau=tau) assert env.tau == tau -@pytest.mark.parametrize("env_id", env_ids) -def test_reset(env_id): +@pytest.mark.parametrize("env_type", envs_to_test) +def test_reset(env_type): batch_size = 4 - env = excenvs.make(env_id, batch_size=batch_size) + env = env_type.make(batch_size=batch_size) key = jax.random.PRNGKey(seed=1234) keys = jax.random.split(key, num=batch_size) @@ -52,10 +52,10 @@ def test_reset(env_id): assert type(state) == env.State, f"Default vmap_reset returns different state type." -@pytest.mark.parametrize("env_id", fully_observable_env_ids) -def test_gen_observation_gen_state(env_id): +@pytest.mark.parametrize("env_type", fully_observable_envs) +def test_gen_observation_gen_state(env_type): batch_size = 4 - env = excenvs.make(env_id, batch_size=batch_size) + env = env_type.make(batch_size=batch_size) # single obs, state = env.reset(env.env_properties) @@ -77,10 +77,10 @@ def test_gen_observation_gen_state(env_id): ) -@pytest.mark.parametrize("env_id", env_ids) -def test_step(env_id): +@pytest.mark.parametrize("env_type", envs_to_test) +def test_step(env_type): batch_size = 4 - env = excenvs.make(env_id, batch_size=batch_size) + env = env_type.make(batch_size=batch_size) # single init_obs, state = env.reset(env.env_properties) init_state_struct = tree_structure(state) @@ -100,11 +100,11 @@ def test_step(env_id): assert init_state_struct == tree_structure(state), "State changes structure during vmapped simulation steps." -@pytest.mark.parametrize("env_id", env_ids) -def test_simulate_ahead(env_id): +@pytest.mark.parametrize("env_type", envs_to_test) +def test_simulate_ahead(env_type): sim_steps = 10 batch_size = 4 - env = excenvs.make(env_id, batch_size=batch_size) + env = env_type.make(batch_size=batch_size) # single obs, init_state = env.reset(env.env_properties) acts = jnp.ones((sim_steps, env.action_dim)) @@ -131,11 +131,11 @@ def test_simulate_ahead(env_id): ), "State changes structure during vmapped simulate ahead." -@pytest.mark.parametrize("env_id", env_ids) -def test_similarity_step_sim_ahead_results(env_id): +@pytest.mark.parametrize("env_type", envs_to_test) +def test_similarity_step_sim_ahead_results(env_type): sim_steps = 10 batch_size = 4 - env = excenvs.make(env_id, batch_size=batch_size, solver=diffrax.Euler()) + env = env_type.make(batch_size=batch_size, solver=diffrax.Euler()) # single obs, state = env.reset(env.env_properties) diff --git a/tests/test_gym_wrapper.py b/tests/test_gym_wrapper.py index 2cc7547..32feb3c 100644 --- a/tests/test_gym_wrapper.py +++ b/tests/test_gym_wrapper.py @@ -11,14 +11,15 @@ jax.config.update("jax_platform_name", "cpu") jax.config.update("jax_enable_x64", True) +from exciting_environments import EnvironmentType -env_ids = ["Pendulum-v0", "MassSpringDamper-v0", "CartPole-v0", "FluidTank-v0", "PMSM-v0"] +envs_to_test = list(EnvironmentType) -@pytest.mark.parametrize("env_id", env_ids) -def test_step_returns_correct_outputs(env_id): +@pytest.mark.parametrize("env_type", envs_to_test) +def test_step_returns_correct_outputs(env_type): """Ensure step function returns outputs of expected type and shape.""" - env = excenvs.make(env_id, batch_size=4) + env = env_type.make(batch_size=4) gym_env = excenvs.GymWrapper(env=env) action = jnp.ones((env.batch_size, env.action_dim)) @@ -35,9 +36,9 @@ def test_step_returns_correct_outputs(env_id): assert terminated.shape == (4, 1), "Unexpected terminated shape" -@pytest.mark.parametrize("env_id", env_ids) -def test_gym_wrapper_ref_generation(env_id): - env = excenvs.make(env_id, batch_size=4) +@pytest.mark.parametrize("env_type", envs_to_test) +def test_gym_wrapper_ref_generation(env_type): + env = env_type.make(batch_size=4) gym_env = excenvs.GymWrapper(env=env) rng_env = jax.vmap(jax.random.PRNGKey)(jnp.array([0, 1, 2, 3])) rng_ref = jax.vmap(jax.random.PRNGKey)(jnp.array([0, 1, 2, 3]))