Sample Page Title


In this tutorial, we build a safety-critical reinforcement learning pipeline that learns solely from fixed, offline data rather than live exploration. We design a custom environment, generate a behavior dataset from a constrained policy, and then train both a Behavior Cloning baseline and a Conservative Q-Learning agent using d3rlpy. By structuring the workflow around offline datasets, careful evaluation, and conservative learning objectives, we demonstrate how robust decision-making policies can be trained in settings where unsafe exploration is not an option.

!pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"


import os
import time
import random
import inspect
import numpy as np
import matplotlib.pyplot as plt


import gymnasium as gym
from gymnasium import spaces


import torch
import d3rlpy




SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)




def pick_device():
   if torch.cuda.is_available():
       return "cuda:0"
   return "cpu"




DEVICE = pick_device()
print("d3rlpy:", getattr(d3rlpy, "__version__", "unknown"), "| torch:", torch.__version__, "| device:", DEVICE)




def make_config(cls, **kwargs):
   # Keep only the keyword arguments that this version of cls accepts.
   sig = inspect.signature(cls.__init__)
   allowed = set(sig.parameters.keys())
   allowed.discard("self")
   filtered = {k: v for k, v in kwargs.items() if k in allowed}
   return cls(**filtered)

We set up the environment by installing dependencies, importing libraries, and fixing random seeds for reproducibility. We detect and configure the computation device to ensure consistent execution across systems. We also define a utility that creates configuration objects safely across different d3rlpy versions by discarding keyword arguments a given version does not accept.
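The signature-filtering idea behind make_config can be illustrated without d3rlpy. The sketch below is a minimal, self-contained version; ToyConfig is a hypothetical stand-in for a real config class, and its fields are assumptions made only for illustration.

```python
import inspect
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a library config class; not part of d3rlpy.
    learning_rate: float = 1e-3
    batch_size: int = 32


def make_config(cls, **kwargs):
    # Keep only the keyword arguments that cls.__init__ actually declares.
    allowed = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return cls(**{k: v for k, v in kwargs.items() if k in allowed})


# "temp_learning_rate" is not a ToyConfig field, so it is silently dropped.
cfg = make_config(ToyConfig, learning_rate=3e-4, batch_size=256, temp_learning_rate=3e-4)
print(cfg)
```

Silently dropping unknown keys keeps one script working across library versions, but it also hides typos and renamed parameters, so it can be worth logging the discarded keys.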

class SafetyCriticalGridWorld(gym.Env):
   metadata = {"render_modes": []}

   def __init__(
       self,
       size=15,
       max_steps=80,
       hazard_coords=None,
       start=(0, 0),
       goal=None,
       slip_prob=0.05,
       seed=0,
   ):
       super().__init__()
       self.size = int(size)
       self.max_steps = int(max_steps)
       self.start = tuple(start)
       self.goal = tuple(goal) if goal is not None else (self.size - 1, self.size - 1)
       self.slip_prob = float(slip_prob)

       if hazard_coords is None:
           # Sample hazard cells away from the grid border.
           hz = set()
           rng = np.random.default_rng(seed)
           for _ in range(max(1, self.size // 2)):
               x = rng.integers(2, self.size - 2)
               y = rng.integers(2, self.size - 2)
               hz.add((int(x), int(y)))
           self.hazards = hz
       else:
           self.hazards = set(tuple(x) for x in hazard_coords)

       self.action_space = spaces.Discrete(4)
       self.observation_space = spaces.Box(low=0.0, high=float(self.size - 1), shape=(2,), dtype=np.float32)

       self._rng = np.random.default_rng(seed)
       self._pos = None
       self._t = 0

   def reset(self, *, seed=None, options=None):
       if seed is not None:
           self._rng = np.random.default_rng(seed)
       self._pos = [int(self.start[0]), int(self.start[1])]
       self._t = 0
       obs = np.array(self._pos, dtype=np.float32)
       return obs, {}

   def _clip(self):
       # Keep the agent inside the grid.
       self._pos[0] = int(np.clip(self._pos[0], 0, self.size - 1))
       self._pos[1] = int(np.clip(self._pos[1], 0, self.size - 1))

   def step(self, action):
       self._t += 1

       a = int(action)
       # With probability slip_prob, the intended action is replaced by a random one.
       if self._rng.random() < self.slip_prob:
           a = int(self._rng.integers(0, 4))

       # Action map: 0 = +y, 1 = +x, 2 = -y, 3 = -x.
       if a == 0:
           self._pos[1] += 1
       elif a == 1:
           self._pos[0] += 1
       elif a == 2:
           self._pos[1] -= 1
       elif a == 3:
           self._pos[0] -= 1

       self._clip()

       x, y = int(self._pos[0]), int(self._pos[1])
       terminated = False
       truncated = self._t >= self.max_steps

       # Default per-step cost.
       reward = -1.0

       if (x, y) in self.hazards:
           reward = -100.0
           terminated = True

       if (x, y) == self.goal:
           reward = +50.0
           terminated = True

       obs = np.array([x, y], dtype=np.float32)
       return obs, float(reward), terminated, truncated, {}

We define a safety-critical GridWorld environment with hazards, terminal states, and stochastic transitions. We encode a large penalty for entering a hazard cell, a reward for reaching the goal, and a small cost for every other step. We keep the dynamics fully under the environment's control so that they reflect real-world safety constraints.
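To get a feel for the reward scale, the return of a few idealized episodes can be worked out by hand. The sketch below uses the reward values and defaults from the environment above (step cost -1, goal +50, hazard -100, 15x15 grid, 80-step limit); the helper names are hypothetical, and the best case assumes no slips and no hazard blocking a shortest path.

```python
def best_case_return(size=15, step_reward=-1.0, goal_reward=50.0):
    # Manhattan distance from (0, 0) to (size - 1, size - 1).
    n_moves = 2 * (size - 1)
    # Every move except the last costs step_reward; the last one pays goal_reward.
    return (n_moves - 1) * step_reward + goal_reward


def hazard_return(steps_before_hit, step_reward=-1.0, hazard_reward=-100.0):
    # Ordinary steps followed by one step into a hazard cell.
    return steps_before_hit * step_reward + hazard_reward


def timeout_return(max_steps=80, step_reward=-1.0):
    # The agent wanders until the step limit without terminating.
    return max_steps * step_reward


print(best_case_return())  # 23.0
print(hazard_return(5))    # -105.0
print(timeout_return())    # -80.0
```

Under these numbers a hazard hit is worse than timing out, so a return-maximizing agent should prefer a detour over a risky shortcut.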

def safe_behavior_policy(obs, env: SafetyCriticalGridWorld, epsilon=0.15):
   x, y = int(obs[0]), int(obs[1])
   gx, gy = env.goal

   # Actions that move toward the goal along each axis.
   preferred = []
   if gx > x:
       preferred.append(1)
   elif gx < x:
       preferred.append(3)
   if gy > y:
       preferred.append(0)
   elif gy < y:
       preferred.append(2)

   if len(preferred) == 0:
       preferred = [int(env._rng.integers(0, 4))]

   # Epsilon-random action for dataset diversity (not hazard-filtered).
   if env._rng.random() < epsilon:
       return int(env._rng.integers(0, 4))

   # Drop preferred actions whose next cell is a hazard.
   candidates = []
   for a in preferred:
       nx, ny = x, y
       if a == 0:
           ny += 1
       elif a == 1:
           nx += 1
       elif a == 2:
           ny -= 1
       elif a == 3:
           nx -= 1
       nx = int(np.clip(nx, 0, env.size - 1))
       ny = int(np.clip(ny, 0, env.size - 1))
       if (nx, ny) not in env.hazards:
           candidates.append(a)

   if len(candidates) == 0:
       return preferred[0]
   return int(random.choice(candidates))




def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
   episodes = []
   for i in range(n_episodes):
       obs, _ = env.reset(seed=int(seed + i))
       obs_list = []
       act_list = []
       rew_list = []
       done_list = []

       done = False
       while not done:
           a = safe_behavior_policy(obs, env, epsilon=epsilon)
           nxt, r, terminated, truncated, _ = env.step(a)
           # Both true terminations and time-limit truncations end the episode
           # and are stored as terminal flags here.
           done = bool(terminated or truncated)

           obs_list.append(np.array(obs, dtype=np.float32))
           act_list.append(np.array([a], dtype=np.int64))
           rew_list.append(np.array([r], dtype=np.float32))
           done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))

           obs = nxt

       episodes.append(
           {
               "observations": np.stack(obs_list, axis=0),
               "actions": np.stack(act_list, axis=0),
               "rewards": np.stack(rew_list, axis=0),
               "terminals": np.stack(done_list, axis=0),
           }
       )
   return episodes




def build_mdpdataset(episodes):
   # Concatenate all episodes into flat arrays; terminal flags mark episode ends.
   obs = np.concatenate([ep["observations"] for ep in episodes], axis=0).astype(np.float32)
   acts = np.concatenate([ep["actions"] for ep in episodes], axis=0).astype(np.int64)
   rews = np.concatenate([ep["rewards"] for ep in episodes], axis=0).astype(np.float32)
   terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0).astype(np.float32)

   if hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset"):
       return d3rlpy.dataset.MDPDataset(observations=obs, actions=acts, rewards=rews, terminals=terms)

   raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")

We design a constrained behavior policy that generates offline data without risky exploration. We roll out this policy to collect trajectories and structure them into episodes. We then convert these episodes into a format compatible with d3rlpy's offline learning APIs.
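The conversion step relies on a simple encoding: per-episode arrays are concatenated, and a terminal flag on the last step of each episode preserves the boundaries. The sketch below shows that encoding on plain Python lists; the helper names are hypothetical and this is a simplified picture, not d3rlpy's internal implementation.

```python
def flatten_episodes(episodes):
    """Concatenate per-episode reward lists and mark each episode's last step."""
    rewards, terminals = [], []
    for ep in episodes:
        if not ep:
            continue  # skip empty episodes
        rewards.extend(ep)
        terminals.extend([0.0] * (len(ep) - 1) + [1.0])
    return rewards, terminals


def split_episodes(rewards, terminals):
    """Recover the episode boundaries from the flat lists."""
    episodes, current = [], []
    for r, t in zip(rewards, terminals):
        current.append(r)
        if t == 1.0:
            episodes.append(current)
            current = []
    return episodes


toy = [[-1.0, -1.0, 50.0], [-1.0, -100.0]]
flat_r, flat_t = flatten_episodes(toy)
print(flat_r)  # [-1.0, -1.0, 50.0, -1.0, -100.0]
print(flat_t)  # [0.0, 0.0, 1.0, 0.0, 1.0]
```

Because the flags carry the episode structure, splitting the flat lists again returns the original episodes.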

def _get_episodes_from_dataset(dataset):
   if hasattr(dataset, "episodes") and dataset.episodes is not None:
       return dataset.episodes
   if hasattr(dataset, "get_episodes"):
       return dataset.get_episodes()
   raise AttributeError("Could not find episodes in dataset (d3rlpy version mismatch).")




def _iter_all_observations(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       if obs is None:
           continue
       for o in obs:
           yield o




def _iter_all_transitions(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       acts = getattr(ep, "actions", None)
       rews = getattr(ep, "rewards", None)
       if obs is None or acts is None:
           continue
       n = min(len(obs), len(acts))
       for i in range(n):
           o = obs[i]
           a = acts[i]
           r = rews[i] if rews is not None and i < len(rews) else None
           yield o, a, r




def visualize_dataset(dataset, env, title="Offline Dataset"):
   # Count how often each grid cell appears in the dataset.
   state_visits = np.zeros((env.size, env.size), dtype=np.float32)
   for obs in _iter_all_observations(dataset):
       x, y = int(obs[0]), int(obs[1])
       x = int(np.clip(x, 0, env.size - 1))
       y = int(np.clip(y, 0, env.size - 1))
       state_visits[y, x] += 1

   plt.figure(figsize=(6, 5))
   plt.imshow(state_visits, origin="lower")
   plt.colorbar(label="Visits")
   plt.scatter([env.start[0]], [env.start[1]], marker="o", label="start")
   plt.scatter([env.goal[0]], [env.goal[1]], marker="*", label="goal")
   if len(env.hazards) > 0:
       hz = np.array(list(env.hazards), dtype=np.int32)
       plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
   plt.title(f"{title}: State visitation")
   plt.xlabel("x")
   plt.ylabel("y")
   plt.legend()
   plt.show()

   rewards = []
   for _, _, r in _iter_all_transitions(dataset):
       if r is not None:
           # Rewards may be stored with shape (1,); take the scalar explicitly.
           rewards.append(float(np.asarray(r).reshape(-1)[0]))
   if len(rewards) > 0:
       plt.figure(figsize=(6, 4))
       plt.hist(rewards, bins=60)
       plt.title(f"{title}: Reward distribution")
       plt.xlabel("reward")
       plt.ylabel("count")
       plt.show()

We implement dataset utilities that iterate through episodes rather than assuming flat arrays. We visualize state visitation to understand coverage and data bias in the offline dataset. We also plot the reward distribution to inspect the learning signal available to the agent.
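Coverage can also be summarized as a single number alongside the heatmap. The sketch below counts visits per cell and reports the fraction of grid cells that appear at least once; the function name and the toy observations are assumptions made for illustration.

```python
from collections import Counter


def visitation_stats(observations, size):
    # Count visits per integer grid cell.
    visits = Counter((int(x), int(y)) for x, y in observations)
    # Fraction of the size x size grid that appears at least once.
    coverage = len(visits) / float(size * size)
    return visits, coverage


toy_obs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (2.0, 1.0)]
visits, coverage = visitation_stats(toy_obs, size=3)
print(visits[(1, 1)])      # 2
print(round(coverage, 3))  # 0.444
```

A low coverage value signals that a learned policy will often face states the dataset says little about, which is where conservative methods matter most.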

def rollout_eval(env, algo, n_episodes=25, seed=0):
   returns = []
   lengths = []
   hazard_hits = 0
   goal_hits = 0

   for i in range(n_episodes):
       obs, _ = env.reset(seed=seed + i)
       done = False
       total = 0.0
       steps = 0
       while not done:
           # predict() expects a batch, so add a leading batch dimension.
           a = int(algo.predict(np.asarray(obs, dtype=np.float32)[None, ...])[0])
           obs, r, terminated, truncated, _ = env.step(a)
           total += float(r)
           steps += 1
           done = bool(terminated or truncated)
           if terminated:
               x, y = int(obs[0]), int(obs[1])
               if (x, y) in env.hazards:
                   hazard_hits += 1
               if (x, y) == env.goal:
                   goal_hits += 1

       returns.append(total)
       lengths.append(steps)

   return {
       "return_mean": float(np.mean(returns)),
       "return_std": float(np.std(returns)),
       "len_mean": float(np.mean(lengths)),
       "hazard_rate": float(hazard_hits / max(1, n_episodes)),
       "goal_rate": float(goal_hits / max(1, n_episodes)),
       "returns": np.asarray(returns, dtype=np.float32),
   }




def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
   rng = np.random.default_rng(seed)
   obs_all = []
   act_all = []
   for o, a, _ in _iter_all_transitions(dataset):
       obs_all.append(np.asarray(o, dtype=np.float32))
       act_all.append(int(np.asarray(a).reshape(-1)[0]))
       if len(obs_all) >= 80_000:
           break


   obs_all = np.stack(obs_all, axis=0)
   act_all = np.asarray(act_all, dtype=np.int64)


   # Sample a subset of dataset states without replacement.
   idx = rng.choice(len(obs_all), size=min(sample_obs, len(obs_all)), replace=False)
   obs_probe = obs_all[idx]
   act_probe_data = act_all[idx]
   act_probe_pi = algo.predict(obs_probe).astype(np.int64)


   mismatch = (act_probe_pi != act_probe_data).astype(np.float32)
   return float(mismatch.mean())




def create_discrete_bc(device):
   if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
       cls = d3rlpy.algos.DiscreteBCConfig
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           batch_size=256,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteBC"):
       return d3rlpy.algos.DiscreteBC()
   raise RuntimeError("DiscreteBC not available in this d3rlpy version.")




def create_discrete_cql(device, conservative_weight=6.0):
   if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
       cls = d3rlpy.algos.DiscreteCQLConfig
       # make_config drops any keyword this d3rlpy version does not accept.
       # The conservative penalty weight may be exposed as `alpha` rather than
       # `conservative_weight` on the discrete config, so both names are passed.
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           actor_learning_rate=3e-4,
           critic_learning_rate=3e-4,
           temp_learning_rate=3e-4,
           alpha_learning_rate=3e-4,
           batch_size=256,
           conservative_weight=float(conservative_weight),
           alpha=float(conservative_weight),
           n_action_samples=10,
           rollout_interval=0,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteCQL"):
       algo = d3rlpy.algos.DiscreteCQL()
       if hasattr(algo, "conservative_weight"):
           try:
               algo.conservative_weight = float(conservative_weight)
           except Exception:
               pass
       return algo
   raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")

We define controlled evaluation routines to measure policy performance without uncontrolled exploration. We compute returns and safety metrics, including hazard and goal rates. We also introduce a mismatch diagnostic that quantifies how often the learned policy's action differs from the action recorded in the dataset at the same state.
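The mismatch diagnostic reduces to comparing two action sequences element by element. The sketch below shows the computation on plain lists; the function name and the toy actions are hypothetical.

```python
def mismatch_rate(policy_actions, data_actions):
    # Fraction of probed states where the policy disagrees with the dataset action.
    if len(policy_actions) != len(data_actions):
        raise ValueError("action sequences must have the same length")
    if not data_actions:
        return 0.0
    disagreements = sum(p != d for p, d in zip(policy_actions, data_actions))
    return disagreements / len(data_actions)


policy = [1, 1, 0, 3]
data = [1, 0, 0, 2]
print(mismatch_rate(policy, data))  # 0.5
```

Because the behavior dataset contains epsilon-random actions, some mismatch is expected even for a policy that imitates the behavior policy well, so the number is most useful for comparing agents rather than as an absolute score.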

def main():
   env = SafetyCriticalGridWorld(
       size=15,
       max_steps=80,
       slip_prob=0.05,
       seed=SEED,
   )

   # 1) Collect the offline dataset with the constrained behavior policy.
   raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
   dataset = build_mdpdataset(raw_eps)

   print("dataset built:", type(dataset).__name__)
   visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")

   # 2) Train both agents from the same fixed dataset.
   bc = create_discrete_bc(DEVICE)
   cql = create_discrete_cql(DEVICE, conservative_weight=6.0)

   print("\nTraining Discrete BC (offline)...")
   t0 = time.time()
   bc.fit(
       dataset,
       n_steps=25_000,
       n_steps_per_epoch=2_500,
       experiment_name="grid_bc_offline",
   )
   print("BC train sec:", round(time.time() - t0, 2))

   print("\nTraining Discrete CQL (offline)...")
   t0 = time.time()
   cql.fit(
       dataset,
       n_steps=80_000,
       n_steps_per_epoch=8_000,
       experiment_name="grid_cql_offline",
   )
   print("CQL train sec:", round(time.time() - t0, 2))

   # 3) Evaluate with a small number of controlled rollouts.
   print("\nControlled online evaluation (small number of rollouts):")
   bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
   cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)

   print("BC :", {k: v for k, v in bc_metrics.items() if k != "returns"})
   print("CQL:", {k: v for k, v in cql_metrics.items() if k != "returns"})

   print("\nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
   bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
   cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
   print("BC mismatch rate :", bc_mismatch)
   print("CQL mismatch rate:", cql_mismatch)

   plt.figure(figsize=(6, 4))
   labels = ["BC", "CQL"]
   means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
   stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
   plt.bar(labels, means, yerr=stds)
   plt.ylabel("Return")
   plt.title("Online Rollout Return (Controlled)")
   plt.show()

   plt.figure(figsize=(6, 4))
   plt.plot(np.sort(bc_metrics["returns"]), label="BC")
   plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
   plt.xlabel("Episode (sorted)")
   plt.ylabel("Return")
   plt.title("Return Distribution (Sorted)")
   plt.legend()
   plt.show()

   # 4) Save the trained policies.
   out_dir = "/content/offline_rl_artifacts"
   os.makedirs(out_dir, exist_ok=True)
   bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
   cql_path = os.path.join(out_dir, "grid_cql_policy.pt")

   if hasattr(bc, "save_policy"):
       bc.save_policy(bc_path)
       print("Saved BC policy:", bc_path)
   if hasattr(cql, "save_policy"):
       cql.save_policy(cql_path)
       print("Saved CQL policy:", cql_path)

   print("\nDone.")


if __name__ == "__main__":
   main()

We train both Behavior Cloning and Conservative Q-Learning agents purely from offline data. We compare their performance using controlled rollouts and diagnostic metrics. We finalize the workflow by saving the trained policies and summarizing safety-aware learning outcomes.
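The conservative term that separates CQL from ordinary Q-learning can be sketched for a single state. The snippet assumes the commonly used log-sum-exp form of the penalty, which is added to the usual temporal-difference loss after scaling by a weight; the function name and the toy Q-values are illustrative, and this is not d3rlpy's internal code.

```python
import math


def cql_penalty(q_values, data_action):
    # Numerically stable log-sum-exp over the Q-values of all actions.
    m = max(q_values)
    lse = m + math.log(sum(math.exp(q - m) for q in q_values))
    # Penalty: soft maximum over all actions minus Q of the dataset action.
    return lse - q_values[data_action]


# Toy Q-values for four actions; action 3 has an inflated estimate.
q = [1.0, 1.0, 1.0, 5.0]

# Large when the inflated action is NOT the one seen in the data.
print(cql_penalty(q, data_action=0))
# Small when the highest-valued action is the one seen in the data.
print(cql_penalty(q, data_action=3))
```

Minimizing this term pushes down Q-values of actions the dataset does not support and pushes up the value of the recorded action, which is the mechanism behind the reduced out-of-distribution behavior.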

In conclusion, we demonstrated that Conservative Q-Learning yields a more reliable policy than simple imitation when learning from historical data in safety-sensitive environments. By comparing offline training results, controlled online evaluations, and action-distribution mismatches, we illustrated how conservatism helps reduce risky, out-of-distribution behavior. Overall, we presented a complete, reproducible offline RL workflow that we can extend to more complex domains such as robotics, healthcare, or finance without compromising safety.

