Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ A basic routine is as simple as:
```py
import jax.numpy as jnp
import exciting_environments as excenvs
from exciting_environments import EnvironmentType
from exciting_environments.utils import MinMaxNormalization

env = excenvs.make(
"Pendulum-v0",
env = EnvironmentType.PENDULUM.make(
batch_size=5,
action_normalizations={"torque": MinMaxNormalization(min=-15,max=15)},
tau=2e-2
)
)
obs, state = env.vmap_reset()

actions = jnp.linspace(start=-1, stop=1, num=1000)[None, :, None]
Expand Down Expand Up @@ -45,11 +45,11 @@ alternatively, simulate full trajectories:
```py
import jax.numpy as jnp
import exciting_environments as excenvs
from exciting_environments import EnvironmentType
from exciting_environments.utils import MinMaxNormalization
import diffrax

env = excenvs.make(
"Pendulum-v0",
env = EnvironmentType.PENDULUM.make(
solver=diffrax.Tsit5(),
batch_size=5,
action_normalizations={"torque": MinMaxNormalization(min=-15,max=15)},
Expand All @@ -76,4 +76,4 @@ which produces $5$ identical trajectories in parallel as well:
![](https://github.com/ExcitingSystems/exciting-environments/blob/main/fig/excenvs_pendulum_simulation_example_advanced.png?raw=true)

Note that in this case the Tsit5 ODE solver instead of the default explicit Euler is used.
All solvers used here are from the diffrax library (https://docs.kidger.site/diffrax/).
All solvers used here are from the diffrax library (https://docs.kidger.site/diffrax/).
2 changes: 1 addition & 1 deletion exciting_environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
from .pendulum import Pendulum
from .fluid_tank import FluidTank
from .acrobot import Acrobot
from .registration import make
from .registration import EnvironmentType
from .gym_wrapper import GymWrapper
from .mujoco_wrapper import MujucoWrapper
11 changes: 6 additions & 5 deletions exciting_environments/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from functools import partial
import chex
from abc import ABC
from exciting_environments import spaces
from exciting_environments.registration import make
from exciting_environments import spaces, EnvironmentType

# from exciting_environments.registration import make


class GymWrapper(ABC):
Expand Down Expand Up @@ -58,9 +59,9 @@ def __init__(
self.generate_terminated = self.env.generate_terminated

@classmethod
def from_name(cls, env_id: str, **env_kwargs):
"""Creates GymWrapper with environment based on passed env_id."""
env = make(env_id, **env_kwargs)
def from_env(cls, env_type: EnvironmentType, **env_kwargs):
"""Creates GymWrapper with environment based on passed EnvironmentType."""
env = env_type.make(**env_kwargs)
return cls(env)

def step(self, action):
Expand Down
2 changes: 1 addition & 1 deletion exciting_environments/pmsm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .motor_parameters import default_params
from .motor_parameters import MotorVariant
from .pmsm_env import PMSM # , PMSM_Physical
32 changes: 14 additions & 18 deletions exciting_environments/pmsm/motor_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from exciting_environments.utils import MinMaxNormalization
from copy import deepcopy

from enum import Enum


@jdc.pytree_dataclass
class PhysicalNormalizations:
Expand Down Expand Up @@ -147,21 +149,15 @@ def default_soft_constraints(self, state, action_norm, env_properties):
)


def default_params(name):
"""
Returns default parameters for specified motor configurations.

Args:
name (str): Name of the motor ("BRUSA" or "SEW").

Returns:
MotorConfig: Configuration containing physical constraints, action constraints, static parameters, and LUT data.
"""
if name is None:
return deepcopy(DEFAULT)
elif name == "BRUSA":
return deepcopy(BRUSA)
elif name == "SEW":
return deepcopy(SEW)
else:
raise ValueError(f"Motor name {name} is not known.")
class MotorVariant(Enum):
DEFAULT = "DEFAULT"
BRUSA = "BRUSA"
SEW = "SEW"

def get_params(self):
if self is MotorVariant.BRUSA:
return deepcopy(BRUSA)
elif self is MotorVariant.SEW:
return deepcopy(SEW)
else:
return deepcopy(DEFAULT)
17 changes: 10 additions & 7 deletions exciting_environments/pmsm/pmsm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from copy import deepcopy

from exciting_environments import CoreEnvironment
from exciting_environments.pmsm import default_params
from exciting_environments.pmsm import MotorVariant


# only for alpha/beta -> abc
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
self,
batch_size: int = 8,
saturated=False,
LUT_motor_name: str = None,
motor_variant: MotorVariant = MotorVariant.DEFAULT,
physical_normalizations: dict = None,
action_normalizations: dict = None,
soft_constraints: Callable = None,
Expand All @@ -130,7 +130,7 @@ def __init__(
Args:
batch_size (int): Number of parallel environment simulations. Default: 8
saturated (bool): Permanent magnet flux linkages and inductances are taken from LUT_motor_name specific LUTs. Default: False
LUT_motor_name (str): Sets physical_normalizations, action_normalizations, soft_constraints and static_params to default values for the passed motor name and stores associated LUTs for the possible saturated case. Needed if saturated==True.
motor_variant (MotorVariant): Sets physical_normalizations, action_normalizations, soft_constraints and static_params to default values for the passed motor variant and stores associated LUTs for the possible saturated case. Needed if saturated==True.
physical_normalizations (dict): min-max normalization values of the physical state of the environment.
u_d_buffer (MinMaxNormalization): Direct share of the delayed action due to system deadtime. Default: min=-2 * 400 / 3, max=2 * 400 / 3
u_q_buffer (MinMaxNormalization): Quadrature share of the delayed action due to system deadtime. Default: min=-2 * 400 / 3, max=2 * 400 / 3
Expand Down Expand Up @@ -161,8 +161,8 @@ def __init__(
self.tau = tau
self._solver = solver

if LUT_motor_name is not None:
motor_params = deepcopy(default_params(LUT_motor_name))
if motor_variant != MotorVariant.DEFAULT:
motor_params = motor_variant.get_params()
default_physical_normalizations = motor_params.physical_normalizations.__dict__
default_action_normalizations = motor_params.action_normalizations.__dict__
default_static_params = motor_params.static_params.__dict__
Expand All @@ -187,7 +187,10 @@ def __init__(

else:
if saturated:
raise Exception("LUT_motor_name is needed to load LUTs.")
raise ValueError(
f"MotorVariant '{motor_variant.value}' is not allowed for saturated LUTs. "
"Use a specific motor variant. DEFAULT is only valid for saturated=False."
)

saturated_quants = [
"L_dd",
Expand All @@ -198,7 +201,7 @@ def __init__(
"Psi_q",
]

motor_params = deepcopy(default_params(LUT_motor_name))
motor_params = motor_variant.get_params()
default_physical_normalizations = motor_params.physical_normalizations.__dict__
default_action_normalizations = motor_params.action_normalizations.__dict__
default_static_params = motor_params.static_params.__dict__
Expand Down
50 changes: 24 additions & 26 deletions exciting_environments/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,27 @@
PMSM,
Acrobot,
)


def make(env_id: str, **env_kwargs):
if env_id == "CartPole-v0":
env = CartPole(**env_kwargs)

elif env_id == "MassSpringDamper-v0":
env = MassSpringDamper(**env_kwargs)

elif env_id == "Pendulum-v0":
env = Pendulum(**env_kwargs)

elif env_id == "FluidTank-v0":
env = FluidTank(**env_kwargs)

elif env_id == "PMSM-v0":
env = PMSM(**env_kwargs)

elif env_id == "Acrobot-v0":
env = Acrobot(**env_kwargs)

else:
print(f"No existing environments got env_id ={env_id}")
env = None

return env
from enum import Enum


class EnvironmentType(Enum):
CART_POLE = "CartPole-v0"
MASS_SPRING_DAMPER = "MassSpringDamper-v0"
PENDULUM = "Pendulum-v0"
FLUID_TANK = "FluidTank-v0"
PMSM = "PMSM-v0"
ACROBOT = "Acrobot-v0"

def make(self, **env_kwargs):
env_map = {
EnvironmentType.CART_POLE: CartPole,
EnvironmentType.MASS_SPRING_DAMPER: MassSpringDamper,
EnvironmentType.PENDULUM: Pendulum,
EnvironmentType.FLUID_TANK: FluidTank,
EnvironmentType.PMSM: PMSM,
EnvironmentType.ACROBOT: Acrobot,
}
cls = env_map.get(self)
if cls is None:
raise ValueError(f"Unknown environment: {self}")
return cls(**env_kwargs)
10 changes: 5 additions & 5 deletions tests/envs/acrobot/test_acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import jax.numpy as jnp
import numpy as np
import diffrax
from exciting_environments import EnvironmentType
from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json
from pathlib import Path
import pickle
import os


jax.config.update("jax_enable_x64", True)


Expand All @@ -33,7 +35,7 @@ def test_default_initialization():
"omega_1": MinMaxNormalization(min=-10, max=10),
"omega_2": MinMaxNormalization(min=-10, max=10),
}
env = excenvs.make("Acrobot-v0", batch_size=batch_size)
env = EnvironmentType.ACROBOT.make(batch_size=batch_size)
for key, value in params.items():
env_value = getattr(env.env_properties.static_params, key)
if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray):
Expand Down Expand Up @@ -101,8 +103,7 @@ def test_custom_initialization():
"omega_1": MinMaxNormalization(min=-55, max=10),
"omega_2": MinMaxNormalization(min=-10, max=30),
}
env = excenvs.make(
"Acrobot-v0",
env = EnvironmentType.ACROBOT.make(
batch_size=batch_size,
static_params=params,
physical_normalizations=physical_normalizations,
Expand Down Expand Up @@ -159,8 +160,7 @@ def test_step_results():
loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = (
load_sim_properties_from_json(file_path)
)
env = excenvs.make(
"Acrobot-v0",
env = EnvironmentType.ACROBOT.make(
tau=loaded_tau,
solver=diffrax.Euler(),
static_params=loaded_params,
Expand Down
19 changes: 11 additions & 8 deletions tests/envs/cartpole/test_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import jax.numpy as jnp
import numpy as np
import diffrax
from exciting_environments.utils import MinMaxNormalization,load_sim_properties_from_json
from exciting_environments import EnvironmentType
from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json
from pathlib import Path
import pickle
import os


jax.config.update("jax_enable_x64", True)


def test_default_initialization():
"""Ensure default static parameters and normalizations are not changed by accident."""
batch_size = 4
Expand All @@ -29,7 +32,7 @@ def test_default_initialization():
"theta": MinMaxNormalization(min=-jnp.pi, max=jnp.pi),
"omega": MinMaxNormalization(min=-8, max=8),
}
env = excenvs.make("CartPole-v0", batch_size=batch_size)
env = EnvironmentType.CART_POLE.make(batch_size=batch_size)
for key, value in params.items():
env_value = getattr(env.env_properties.static_params, key)
if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray):
Expand Down Expand Up @@ -94,8 +97,7 @@ def test_static_parameters_initialization():
"m_c": jnp.repeat(1, batch_size),
"g": 35.81,
}
env = excenvs.make(
"CartPole-v0",
env = EnvironmentType.CART_POLE.make(
batch_size=batch_size,
static_params=params,
physical_normalizations=physical_normalizations,
Expand Down Expand Up @@ -150,9 +152,10 @@ def test_static_parameters_initialization():
def test_step_results():
data_dir = os.path.join(Path(__file__).parent, "data")
file_path = os.path.join(data_dir, "sim_properties.json")
loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path)
env = excenvs.make(
"CartPole-v0",
loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = (
load_sim_properties_from_json(file_path)
)
env = EnvironmentType.CART_POLE.make(
tau=loaded_tau,
solver=diffrax.Euler(),
static_params=loaded_params,
Expand All @@ -170,4 +173,4 @@ def test_step_results():
obs, state = env.step(state, action, env.env_properties)
generated_observations.append(obs)
generated_observations = jnp.array(generated_observations)
assert jnp.allclose(generated_observations, stored_observations, 1e-16), "Step function generates different data"
assert jnp.allclose(generated_observations, stored_observations, 1e-16), "Step function generates different data"
15 changes: 8 additions & 7 deletions tests/envs/fluid_tank/test_fluid_tank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import jax.numpy as jnp
import numpy as np
import diffrax
from exciting_environments.utils import MinMaxNormalization,load_sim_properties_from_json
from exciting_environments import EnvironmentType
from exciting_environments.utils import MinMaxNormalization, load_sim_properties_from_json
from pathlib import Path
import pickle
import os
Expand All @@ -18,7 +19,7 @@ def test_default_initialization():
params = {"base_area": jnp.pi, "orifice_area": jnp.pi * 0.1**2, "c_d": 0.6, "g": 9.81}
action_normalizations = {"inflow": MinMaxNormalization(min=0, max=0.2)}
physical_normalizations = {"height": MinMaxNormalization(min=0, max=3)}
env = excenvs.make("FluidTank-v0", batch_size=batch_size)
env = EnvironmentType.FLUID_TANK.make(batch_size=batch_size)
for key, value in params.items():
env_value = getattr(env.env_properties.static_params, key)
if isinstance(value, jnp.ndarray) or isinstance(env_value, jnp.ndarray):
Expand Down Expand Up @@ -71,8 +72,7 @@ def test_custom_initialization():
params = {"base_area": jnp.repeat(jnp.pi, batch_size), "orifice_area": jnp.pi * 0.1**2, "c_d": 0.6, "g": 9.81}
action_normalizations = {"inflow": MinMaxNormalization(min=jnp.repeat(0.02, batch_size), max=0.3)}
physical_normalizations = {"height": MinMaxNormalization(min=1, max=5)}
env = excenvs.make(
"FluidTank-v0",
env = EnvironmentType.FLUID_TANK.make(
batch_size=batch_size,
static_params=params,
physical_normalizations=physical_normalizations,
Expand Down Expand Up @@ -127,9 +127,10 @@ def test_custom_initialization():
def test_step_results():
data_dir = os.path.join(Path(__file__).parent, "data")
file_path = os.path.join(data_dir, "sim_properties.json")
loaded_params,loaded_action_normalizations,loaded_physical_normalizations,loaded_tau=load_sim_properties_from_json(file_path)
env = excenvs.make(
"FluidTank-v0",
loaded_params, loaded_action_normalizations, loaded_physical_normalizations, loaded_tau = (
load_sim_properties_from_json(file_path)
)
env = EnvironmentType.FLUID_TANK.make(
tau=loaded_tau,
solver=diffrax.Euler(),
static_params=loaded_params,
Expand Down
Loading
Loading