####################################################################### # Name: test_info_surfing.py # # - Runs robot in environment using Info Surfing Planner ####################################################################### import sys sys.modules['TRAINING'] = False # False = Inference Testing import copy import os import imageio import numpy as np import matplotlib.pyplot as plt from pathlib import Path from time import time from types import SimpleNamespace from skimage.transform import resize from taxabind_avs.satbind.kmeans_clustering import CombinedSilhouetteInertiaClusterer from .env import Env from .test_parameter import * OPPOSITE_ACTIONS = {1: 3, 2: 4, 3: 1, 4: 2, 5: 7, 6: 8, 7: 5, 8: 6} # color agentColor = (1, 0.2, 0.6) agentCommColor = (1, 0.6, 0.2) obstacleColor = (0., 0., 0.) targetNotFound = (0., 1., 0.) targetFound = (0.545, 0.27, 0.075) highestProbColor = (1., 0., 0.) highestUncertaintyColor = (0., 0., 1.) lowestProbColor = (1., 1., 1.) class ISEnv: """Custom Environment that follows gym interface""" metadata = {'render.modes': ['human']} def __init__(self, global_step=0, state=None, shape=(24, 24), numAgents=8, observationSize=11, sensorSize=1, diag=False, save_image=False, clip_seg_tta=None): self.global_step = global_step self.infoMap = None self.targetMap = None self.agents = [] self.targets = [] self.numAgents = numAgents self.found_target = [] self.shape = shape self.observationSize = observationSize self.sensorSize = sensorSize self.diag = diag self.communicateCircle = 11 self.distribs = [] self.mask = None self.finished = False self.action_vects = [[-1., 0.], [0., 1.], [1., 0], [0., -1.]] if not diag else [[-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]] self.actionlist = [] self.IS_step = 0 self.save_image = save_image self.clip_seg_tta = clip_seg_tta self.perf_metrics = dict() self.steps_to_first_tgt = None self.steps_to_mid_tgt = None self.steps_to_last_tgt = None self.targets_found_on_path = [] self.step_since_tta = 0 self.IS_frame_files = [] self.bad_mask_init = False # define env self.env = Env(map_index=self.global_step, n_agent=numAgents, k_size=K_SIZE, plot=save_image, test=True) # Overwrite state if self.clip_seg_tta is not None: self.clip_seg_tta.reset(sample_idx=self.global_step) # Override target positions in env self.env.target_positions = [(pose[1], pose[0]) for pose in self.clip_seg_tta.target_positions] # Override segmentation mask if not USE_CLIP_PREDS and OVERRIDE_MASK_DIR != "": score_mask_path = os.path.join(OVERRIDE_MASK_DIR, self.clip_seg_tta.gt_mask_name) print("score_mask_path: ", score_mask_path) if os.path.exists(score_mask_path): self.env.segmentation_mask = self.env.import_segmentation_mask(score_mask_path) self.env.begin(self.env.map_start_position) else: print(f"\n\n{RED}ERROR: Trying to override, but score mask not found at path:{NC} ", score_mask_path) self.bad_mask_init = True # Save clustered embeds from sat encoder if USE_CLIP_PREDS: self.kmeans_clusterer = CombinedSilhouetteInertiaClusterer( k_min=1, k_max=8, k_avg_max=4, silhouette_threshold=0.15, relative_threshold=0.15, random_state=0, min_patch_size=5, n_smooth_iter=2, ignore_label=-1, plot=self.save_image, gifs_dir = GIFS_PATH ) # Generate kmeans clusters self.kmeans_sat_embeds_clusters = self.kmeans_clusterer.fit_predict( patch_embeds=self.clip_seg_tta.patch_embeds, map_shape=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]), ) if EXECUTE_TTA: print("Will execute TTA...") IS_info_map = copy.deepcopy(self.env.segmentation_info_mask) IS_agent_loc = copy.deepcopy(self.env.start_positions) IS_target_loc = copy.deepcopy(self.env.target_positions) state=[IS_info_map, IS_agent_loc, IS_target_loc] self.setWorld(state) def init_render(self): """ Call this once (e.g., in __init__ or just before the scenario loop) to initialize storage for agent paths and turn interactive plotting on. """ # Keep track of each agent's trajectory self.trajectories = [[] for _ in range(self.numAgents)] self.trajectories_upscaled = [[] for _ in range(self.numAgents)] # Turn on interactive mode so we can update the same figure repeatedly plt.ion() plt.figure(figsize=(6,6)) plt.title("Information Map with Agents, Targets, and Sensor Ranges") def record_positions(self): """ Call this after all agents have moved in a step (or whenever you want to update the trajectory). It appends the current positions of each agent to `self.trajectories`. """ for idx, agent in enumerate(self.agents): self.trajectories[idx].append((agent.row, agent.col)) self.trajectories_upscaled[idx].append(self.env.graph_generator.grid_coords[agent.row, agent.col]) def render(self, episode_num, step_num): """ Renders the current state in a single matplotlib plot. Ensures consistent image size for GIF generation. """ # Completely reset the figure to avoid leftover state plt.close('all') fig = plt.figure(figsize=(6.4, 4.8), dpi=100) ax = fig.add_subplot(111) # Plot the information map ax.imshow(self.infoMap, origin='lower', cmap='gray') # Show agent positions and their trajectories for idx, agent in enumerate(self.agents): positions = self.trajectories[idx] if len(positions) > 1: rows = [p[0] for p in positions] cols = [p[1] for p in positions] ax.plot(cols, rows, linewidth=1) ax.scatter(agent.col, agent.row, marker='o', s=50) # Plot target locations for t in self.targets: color = 'green' if np.isnan(t.time_found) else 'red' ax.scatter(t.col, t.row, marker='x', s=100, color=color) # Title and axis formatting ax.set_title(f"Step: {self.IS_step}") ax.invert_yaxis() # Create output folder if it doesn't exist if not os.path.exists(GIFS_PATH): os.makedirs(GIFS_PATH) # Save the frame with consistent canvas frame_path = f'{GIFS_PATH}/IS_{episode_num}_{step_num}.png' plt.savefig(frame_path, bbox_inches='tight', pad_inches=0.1) self.IS_frame_files.append(frame_path) # Cleanup plt.close(fig) def setWorld(self, state=None): """ 1. empty all the element 2. create the new episode """ if state is not None: self.infoMap = copy.deepcopy(state[0].reshape(self.shape).T) agents = [] self.numAgents = len(state[1]) for a in range(1, self.numAgents + 1): abs_pos = state[1].pop(0) abs_pos = np.array(abs_pos) row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(np.array(abs_pos)) agents.append(Agent(ID=a, row=row, col=col, sensorSize=self.sensorSize, infoMap=np.copy(self.infoMap), uncertaintyMap=np.copy(self.infoMap), shape=self.shape, numAgents=self.numAgents)) self.agents = agents targets, n_targets = [], 1 for t in range(len(state[2])): abs_pos = state[2].pop(0) abs_pos = np.array(abs_pos) row, col = self.env.graph_generator.find_closest_index_from_grid_coords_2d(abs_pos) targets.append(Target(ID=n_targets, row=row, col=col, time_found=np.nan)) n_targets = n_targets + 1 self.targets = targets def extractObservation(self, agent): """ Extract observations from information map """ transform_row = self.observationSize // 2 - agent.row transform_col = self.observationSize // 2 - agent.col observation_layers = np.zeros((1, self.observationSize, self.observationSize)) min_row = max((agent.row - self.observationSize // 2), 0) max_row = min((agent.row + self.observationSize // 2 + 1), self.shape[0]) min_col = max((agent.col - self.observationSize // 2), 0) max_col = min((agent.col + self.observationSize // 2 + 1), self.shape[1]) observation = np.full((self.observationSize, self.observationSize), 0.) infoMap = np.full((self.observationSize, self.observationSize), 0.) densityMap = np.full((self.observationSize, self.observationSize), 0.) infoMap[(min_row + transform_row):(max_row + transform_row), (min_col + transform_col):(max_col + transform_col)] = self.infoMap[ min_row:max_row, min_col:max_col] observation_layers[0] = infoMap return observation_layers def listNextValidActions(self, agent_id, prev_action=0): """ No movement: 0 North (-1,0): 1 East (0,1): 2 South (1,0): 3 West (0,-1): 4 """ available_actions = [0] agent = self.agents[agent_id - 1] MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1), (-1, -1), (-1, 1), (1, 1), (1, -1)] size = 4 + self.diag * 4 for action in range(size): out_of_bounds = agent.row + MOVES[action][0] >= self.shape[0] \ or agent.row + MOVES[action][0] < 0\ or agent.col + MOVES[action][1] >= self.shape[1] \ or agent.col + MOVES[action][1] < 0 if (not out_of_bounds) and not (prev_action == OPPOSITE_ACTIONS[action + 1]): available_actions.append(action + 1) return np.array(available_actions) def executeAction(self, agentID, action, timeStep): """ No movement: 0 North (-1,0): 1 East (0,1): 2 South (1,0): 3 West (0,-1): 4 LeftUp (-1,-1) : 5 RightUP (-1,1) :6 RightDown (1,1) :7 RightLeft (1,-1) :8 """ agent = self.agents[agentID - 1] origLoc = agent.getLocation() if (action >= 1) and (action <= 8): agent.move(action) row, col = agent.getLocation() # If the move is not valid, roll it back if (row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1]): self.updateInfoCheckTarget(agentID, timeStep, origLoc) return 0 elif action == 0: self.updateInfoCheckTarget(agentID, timeStep, origLoc) return 0 else: print("INVALID ACTION: {}".format(action)) sys.exit() newLoc = agent.getLocation() self.updateInfoCheckTarget(agentID, timeStep, origLoc) return action def updateInfoCheckTarget(self, agentID, timeStep, origLoc): """ update the self.infoMap and check whether the agent has found a target """ agent = self.agents[agentID - 1] transform_row = self.sensorSize // 2 - agent.row transform_col = self.sensorSize // 2 - agent.col min_row = max((agent.row - self.sensorSize // 2), 0) max_row = min((agent.row + self.sensorSize // 2 + 1), self.shape[0]) min_col = max((agent.col - self.sensorSize // 2), 0) max_col = min((agent.col + self.sensorSize // 2 + 1), self.shape[1]) for t in self.targets: if (t.row == agent.row) and (t.col == agent.col): t.updateFound(timeStep) self.found_target.append(t) t.status = True self.infoMap[min_row:max_row, min_col:max_col] *= 0.05 def updateInfoEntireTrajectory(self, agentID): """ update the self.infoMap and check whether the agent has found a target """ traj = self.trajectories[agentID - 1] for (row,col) in traj: min_row = max((row - self.sensorSize // 2), 0) max_row = min((row + self.sensorSize // 2 + 1), self.shape[0]) min_col = max((col - self.sensorSize // 2), 0) max_col = min((col + self.sensorSize // 2 + 1), self.shape[1]) self.infoMap[min_row:max_row, min_col:max_col] *= 0.05 # Execute one time step within the environment def step(self, agentID, action, timeStep): """ the agents execute the actions No movement: 0 North (-1,0): 1 East (0,1): 2 South (1,0): 3 West (0,-1): 4 """ assert (agentID > 0) self.executeAction(agentID, action, timeStep) def observe(self, agentID): assert (agentID > 0) vectorObs = self.extractObservation(self.agents[agentID - 1]) return [vectorObs] def check_finish(self): if TERMINATE_ON_TGTS_FOUND: found_status = [t.time_found for t in self.targets] d = False if np.isnan(found_status).sum() == 0: d = True return d else: return False def gradVec(self, observation, agent): a = observation[0] # Make info & unc cells with low value as 0 a[a < 0.0002] = 0.0 # Center square from 11x11 a_11x11 = a[4:7, 4:7] m_11x11 = np.array((a_11x11)) # Center square from 9x9 a_9x9 = self.pooling(a, (3, 3), stride=(1, 1), method='max', pad=False) a_9x9 = a_9x9[3:6, 3:6] m_9x9 = np.array((a_9x9)) # Center square from 6x6 a_6x6 = self.pooling(a, (6, 6), stride=(1, 1), method='max', pad=False) a_6x6 = a_6x6[1:4, 1:4] m_6x6 = np.array((a_6x6)) # Center square from 3x3 a_3x3 = self.pooling(a, (5, 5), stride=(3, 3), method='max', pad=False) m_3x3 = np.array((a_3x3)) # Merging multiScales with weights m = m_3x3 * 0.25 + m_6x6 * 0.25 + m_9x9 * 0.25 + m_11x11 * 0.25 a = m adx, ady = np.gradient(a) den = np.linalg.norm(np.array([adx[1, 1], ady[1, 1]])) if (den != 0) and (not np.isnan(den)): infovec = np.array([adx[1, 1], ady[1, 1]]) / den else: infovec = 0 agentvec = [] if len(agentvec) == 0: den = np.linalg.norm(infovec) if (den != 0) and (not np.isnan(den)): direction = infovec / den else: direction = self.action_vects[np.random.randint(4 + self.diag * 4)] else: den = np.linalg.norm(np.mean(agentvec, 0)) if (den != 0) and (not np.isnan(den)): agentvec = np.mean(agentvec, 0) / den else: agentvec = 0 den = np.linalg.norm(0.6 * infovec + 0.4 * agentvec) if (den != 0) and (not np.isnan(den)): direction = (0.6 * infovec + 0.4 * agentvec) / den else: direction = self.action_vects[np.random.randint(4 + self.diag * 4)] action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]] actionid = np.argmax([np.dot(direction, a) for a in action_vec]) actionid = self.best_valid_action(actionid, agent, direction) return actionid def best_valid_action(self, actionid, agent, direction): if len(self.actionlist) > 1: if self.action_invalid(actionid, agent): action_vec = [[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]] if not self.diag else [[0., 0.], [-1., 0.], [0., 1.], [1., 0], [0., -1.], [-0.707, -0.707], [-0.707, 0.707], [0.707, 0.707], [0.707, -0.707]] actionid = np.array([np.dot(direction, a) for a in action_vec]) actionid = actionid.argsort() pi = 3 + self.diag*4 while self.action_invalid(actionid[pi], agent) and pi >= 0: pi -= 1 if pi == -1: return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]] elif actionid[pi] == 0: return OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]] else: return actionid[pi] return actionid def action_invalid(self, action, agent): # Going back to the previous cell is disabled if action == OPPOSITE_ACTIONS[self.actionlist[self.IS_step - 1][agent - 1]]: return True # Move N,E,S,W if (action >= 1) and (action <= 8): agent = self.agents[agent - 1] agent.move(action) row, col = agent.getLocation() # If the move is not valid, roll it back if ((row < 0) or (col < 0) or (row >= self.shape[0]) or (col >= self.shape[1])): agent.reverseMove(action) return True agent.reverseMove(action) return False return False def step_all_parallel(self): actions = [] reward = 0 # Decide actions for each agent for agent_id in range(1, self.numAgents + 1): o = self.observe(agent_id) actions.append(self.gradVec(o[0], agent_id)) self.actionlist.append(actions) # Execute those actions for agent_id in range(1, self.numAgents + 1): self.step(agent_id, actions[agent_id - 1], self.IS_step) # Record for visualization self.record_positions() def is_scenario(self, max_step=512, episode_number=0): # Return all metrics as None if faulty mask init if self.bad_mask_init: self.perf_metrics['tax'] = None self.perf_metrics['travel_dist'] = None self.perf_metrics['travel_steps'] = None self.perf_metrics['steps_to_first_tgt'] = None self.perf_metrics['steps_to_mid_tgt'] = None self.perf_metrics['steps_to_last_tgt'] = None self.perf_metrics['explored_rate'] = None self.perf_metrics['targets_found'] = None self.perf_metrics['targets_total'] = None self.perf_metrics['kmeans_k'] = None self.perf_metrics['tgts_gt_score'] = None self.perf_metrics['clip_inference_time'] = None self.perf_metrics['tta_time'] = None self.perf_metrics['success_rate'] = None return eps_start = time() self.IS_step = 0 self.finished = False reward = 0 # Initialize the rendering just once before the loop self.init_render() self.record_positions() # Initial Setup if LOAD_AVS_BENCH and USE_CLIP_PREDS: if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1) unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1) self.infoMap = copy.deepcopy(heatmap) print("Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT) else: self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1) self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1) self.infoMap = copy.deepcopy(self.clip_seg_tta.heatmap) self.targets_found_on_path.append(self.env.num_new_targets_found) while self.IS_step < max_step and not self.check_finish(): self.step_all_parallel() self.IS_step += 1 # Render after each step if self.save_image: self.render(episode_num=self.global_step, step_num=self.IS_step) # Update in env next_position_list = [self.trajectories_upscaled[i][-1] for i, agent in enumerate(self.agents)] dist_list = [0 for _ in range(self.numAgents)] travel_dist_list = [self.compute_travel_distance(traj) for traj in self.trajectories] self.env.multi_robot_step(next_position_list, dist_list, travel_dist_list) self.targets_found_on_path.append(self.env.num_new_targets_found) # TTA Update via Poisson Test (with KMeans clustering stats) robot_id = 0 # Assume 1 agent for now robot_traj = self.trajectories[robot_id] if LOAD_AVS_BENCH and USE_CLIP_PREDS and EXECUTE_TTA: flat_traj_coords = [robot_traj[i][1] * self.shape[0] + robot_traj[i][0] for i in range(len(robot_traj))] robot = SimpleNamespace( trajectory_coords=flat_traj_coords, targets_found_on_path=self.targets_found_on_path ) self.poisson_tta_update(robot, self.global_step, self.IS_step) self.infoMap = copy.deepcopy(self.env.segmentation_info_mask.reshape((self.shape[1],self.shape[0])).T) self.updateInfoEntireTrajectory(robot_id) # Update metrics self.log_metrics(step=self.IS_step-1) ### Save a frame to generate gif of robot trajectories ### if self.save_image: robots_route = [ ([], []) ] # Assume 1 robot for point in self.trajectories_upscaled[robot_id]: robots_route[robot_id][0].append(point[0]) robots_route[robot_id][1].append(point[1]) if not os.path.exists(GIFS_PATH): os.makedirs(GIFS_PATH) if LOAD_AVS_BENCH: sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0] self.env.plot_env( self.global_step, GIFS_PATH, self.IS_step-1, max(travel_dist_list), robots_route, img_path_override=self.clip_seg_tta.img_paths[0], # Viz 1st sat_path_override=self.clip_seg_tta.imo_path, msk_name_override=self.clip_seg_tta.species_name, sound_id_override=sound_id_override, ) else: self.env.plot_env( self.global_step, GIFS_PATH, self.IS_step-1, max(travel_dist_list), robots_route ) # Log metrics if LOAD_AVS_BENCH: tax = Path(self.clip_seg_tta.gt_mask_name).stem self.perf_metrics['tax'] = " ".join(tax.split("_")[1:]) else: self.perf_metrics['tax'] = None travel_distances = [self.compute_travel_distance(traj) for traj in self.trajectories] self.perf_metrics['travel_dist'] = max(travel_distances) self.perf_metrics['travel_steps'] = self.IS_step self.perf_metrics['steps_to_first_tgt'] = self.steps_to_first_tgt self.perf_metrics['steps_to_mid_tgt'] = self.steps_to_mid_tgt self.perf_metrics['steps_to_last_tgt'] = self.steps_to_last_tgt self.perf_metrics['explored_rate'] = self.env.explored_rate self.perf_metrics['targets_found'] = self.env.targets_found_rate self.perf_metrics['targets_total'] = len(self.env.target_positions) if USE_CLIP_PREDS: self.perf_metrics['kmeans_k'] = self.kmeans_clusterer.final_k self.perf_metrics['tgts_gt_score'] = self.clip_seg_tta.tgts_gt_score self.perf_metrics['clip_inference_time'] = self.clip_seg_tta.clip_inference_time self.perf_metrics['tta_time'] = self.clip_seg_tta.tta_time else: self.perf_metrics['kmeans_k'] = None self.perf_metrics['tgts_gt_score'] = None self.perf_metrics['clip_inference_time'] = None self.perf_metrics['tta_time'] = None if FORCE_LOGGING_DONE_TGTS_FOUND and self.env.targets_found_rate == 1.0: self.perf_metrics['success_rate'] = True else: self.perf_metrics['success_rate'] = self.env.check_done()[0] # save gif if self.save_image: path = GIFS_PATH self.make_gif(path, self.global_step) print(YELLOW, f"[Eps {episode_number} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {self.IS_step}", NC) def asStride(self, arr, sub_shape, stride): """ Get a strided sub-matrices view of an ndarray. See also skimage.util.shape.view_as_windows() """ s0, s1 = arr.strides[:2] m1, n1 = arr.shape[:2] m2, n2 = sub_shape view_shape = (1+(m1-m2)//stride[0], 1+(n1-n2)//stride[1], m2, n2)+arr.shape[2:] strides = (stride[0]*s0, stride[1]*s1, s0, s1)+arr.strides[2:] subs = np.lib.stride_tricks.as_strided(arr, view_shape, strides=strides) return subs def pooling(self, mat, ksize, stride=None, method='max', pad=False): """ Overlapping pooling on 2D or 3D data. : ndarray, input array to pool. : tuple of 2, kernel size in (ky, kx). : tuple of 2 or None, stride of pooling window. If None, same as (non-overlapping pooling). : str, 'max for max-pooling, 'mean' for mean-pooling. : bool, pad or not. If no pad, output has size (n-f)//s+1, n being size, f being kernel size, s stride. if pad, output has size ceil(n/s). Return : pooled matrix. """ m, n = mat.shape[:2] ky, kx = ksize if stride is None: stride = (ky, kx) sy, sx = stride _ceil = lambda x, y: int(np.ceil(x/float(y))) if pad: ny = _ceil(m,sy) nx = _ceil(n,sx) size = ((ny-1)*sy+ky, (nx-1)*sx+kx) + mat.shape[2:] mat_pad = np.full(size,np.nan) mat_pad[:m,:n,...] = mat else: mat_pad = mat[:(m-ky)//sy*sy+ky, :(n-kx)//sx*sx+kx, ...] view = self.asStride(mat_pad,ksize,stride) if method == 'max': result = np.nanmax(view,axis=(2,3)) else: result = np.nanmean(view,axis=(2,3)) return result def compute_travel_distance(self, trajectory): distance = 0.0 for i in range(1, len(trajectory)): # Convert the tuple positions to numpy arrays for easy computation. prev_pos = np.array(trajectory[i-1]) curr_pos = np.array(trajectory[i]) # Euclidean distance between consecutive positions. distance += np.linalg.norm(curr_pos - prev_pos) return distance ################################################################################ # SPPP Related Fns ################################################################################ def log_metrics(self, step): # Update tgt found metrics if self.steps_to_first_tgt is None and self.env.num_targets_found == 1: self.steps_to_first_tgt = step + 1 if self.steps_to_mid_tgt is None and self.env.num_targets_found == int(len(self.env.target_positions) / 2): self.steps_to_mid_tgt = step + 1 if self.steps_to_last_tgt is None and self.env.num_targets_found == len(self.env.target_positions): self.steps_to_last_tgt = step + 1 def transpose_flat_idx(self, idx, H=NUM_COORDS_HEIGHT, W=NUM_COORDS_WIDTH): """ Transpose a flat index from an ``H×W`` grid to the equivalent position in the ``W×H`` transposed grid while **keeping the result in 1-D**. """ # --- Safety check to catch out-of-range indices --- assert 0 <= idx < H * W, f"idx {idx} out of bounds for shape ({H}, {W})" # Original (row, col) row, col = divmod(idx, W) # After transpose these coordinates swap row_T, col_T = col, row # Flatten back into 1-D (row-major) for the W×H grid return row_T * H + col_T def poisson_tta_update(self, robot, episode, step): # Generate Kmeans Clusters Stats # Scale index back to CLIP_GRIDS_DIMS to be compatible with CLIP patch size if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # High-res remap via pixel coordinates preserves exact neighbourhood filt_traj_coords, filt_targets_found_on_path = self.scale_trajectory( robot.trajectory_coords, self.env.target_positions, old_dims=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH), full_dims=(512, 512), new_dims=(CLIP_GRIDS_DIMS[0], CLIP_GRIDS_DIMS[1]) ) else: filt_traj_coords = [self.transpose_flat_idx(idx) for idx in robot.trajectory_coords] filt_targets_found_on_path = robot.targets_found_on_path region_stats_dict = self.kmeans_clusterer.compute_region_statistics( self.kmeans_sat_embeds_clusters, self.clip_seg_tta.heatmap_unnormalized, filt_traj_coords, episode_num=episode, step_num=step ) # Prep & execute TTA self.step_since_tta += 1 if robot.targets_found_on_path[-1] or self.step_since_tta % STEPS_PER_TTA == 0: num_cells = self.clip_seg_tta.heatmap.shape[0] * self.clip_seg_tta.heatmap.shape[1] pos_sample_weight_scale, neg_sample_weight_scale = [], [] for i, sample_loc in enumerate(filt_traj_coords): label = self.kmeans_clusterer.get_label_id(self.kmeans_sat_embeds_clusters, sample_loc) num_patches = region_stats_dict[label]['num_patches'] patches_visited = region_stats_dict[label]['patches_visited'] expectation = region_stats_dict[label]['expectation'] # Exponent like focal loss to wait for more samples before confidently decreasing pos_weight = 4.0 neg_weight = min(1.0, (patches_visited/(3*num_patches))**GAMMA_EXPONENT) pos_sample_weight_scale.append(pos_weight) neg_sample_weight_scale.append(neg_weight) # Adaptative LR (as samples increase, increase LR to fit more datapoints) adaptive_lr = MIN_LR + (MAX_LR - MIN_LR) * (step / num_cells) # TTA Update self.clip_seg_tta.execute_tta( filt_traj_coords, filt_targets_found_on_path, tta_steps=NUM_TTA_STEPS, lr=adaptive_lr, pos_sample_weight=pos_sample_weight_scale, neg_sample_weight=neg_sample_weight_scale, reset_weights=RESET_WEIGHTS ) if NUM_COORDS_WIDTH != CLIP_GRIDS_DIMS[0] or NUM_COORDS_HEIGHT != CLIP_GRIDS_DIMS[1]: # If heatmap is resized from clip original dims heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) self.env.segmentation_info_mask = np.expand_dims(heatmap.T.flatten(), axis=1) unnormalized_heatmap = self.convert_heatmap_resolution(self.clip_seg_tta.heatmap_unnormalized, full_dims=(512, 512), new_dims=(NUM_COORDS_WIDTH, NUM_COORDS_HEIGHT)) self.env.segmentation_info_mask_unnormalized = np.expand_dims(unnormalized_heatmap.T.flatten(), axis=1) print("~Resized heatmap to", NUM_COORDS_WIDTH, "x", NUM_COORDS_HEIGHT) else: self.env.segmentation_info_mask = np.expand_dims(self.clip_seg_tta.heatmap.T.flatten(), axis=1) self.env.segmentation_info_mask_unnormalized = np.expand_dims(self.clip_seg_tta.heatmap_unnormalized.T.flatten(), axis=1) self.step_since_tta = 0 def convert_heatmap_resolution(self, heatmap, full_dims=(512, 512), new_dims=(24, 24)): heatmap_large = resize(heatmap, full_dims, order=1, # order=1 → bilinear mode='reflect', anti_aliasing=True) coords = self.env.graph_generator.grid_coords # (N, N, 2) rows, cols = coords[...,1], coords[...,0] heatmap_resized = heatmap_large[rows, cols] heatmap_resized = heatmap_resized.reshape(new_dims[1], new_dims[0]) return heatmap_resized def convert_labelmap_resolution(self, labelmap, full_dims=(512, 512), new_dims=(24, 24)): """ 1) Upsample via nearest‐neighbor to full_dims 2) Sample back down to your graph grid using grid_coords """ # 1) Upsample with nearest‐neighbor, preserving integer labels up = resize( labelmap, full_dims, order=0, # nearest‐neighbor mode='edge', # padding mode preserve_range=True, # don't normalize labels anti_aliasing=False # must be False for labels ).astype(labelmap.dtype) # back to original integer dtype # 2) Downsample via your precomputed grid coords (N×N×2) coords = self.env.graph_generator.grid_coords # shape (N, N, 2) rows = coords[...,1].astype(int) cols = coords[...,0].astype(int) small = up[rows, cols] # shape (N, N) small = small.reshape(new_dims[0], new_dims[1]) return small def scale_trajectory(self, flat_indices, targets, old_dims=(17, 17), full_dims=(512, 512), new_dims=(24, 24)): """ Args: flat_indices: list of ints in [0..old_H*old_W-1] targets: list of (y_pix, x_pix) in [0..full_H-1] old_dims: (old_H, old_W) full_dims: (full_H, full_W) new_dims: (new_H, new_W) Returns: new_flat_traj: list of unique flattened indices in new_H×new_W counts: list of ints, same length as new_flat_traj """ old_H, old_W = old_dims full_H, full_W = full_dims new_H, new_W = new_dims # 1) bin targets into new grid cell_h_new = full_H / new_H cell_w_new = full_W / new_W grid_counts = [[0]*new_W for _ in range(new_H)] for x_pix, y_pix in targets: # note (x, y) order as in original implementation i_t = min(int(y_pix / cell_h_new), new_H - 1) j_t = min(int(x_pix / cell_w_new), new_W - 1) grid_counts[i_t][j_t] += 1 # 2) Walk the trajectory indices and project each old cell's *entire # pixel footprint* onto the finer 24×24 grid. cell_h_full = full_H / old_H cell_w_full = full_W / old_W seen = set() new_flat_traj = [] for node_idx in flat_indices: if node_idx < 0 or node_idx >= len(self.env.graph_generator.node_coords): continue coord_xy = self.env.graph_generator.node_coords[node_idx] try: row_old, col_old = self.env.graph_generator.find_index_from_grid_coords_2d(coord_xy) except Exception: continue # Bounding box of the old cell in full-resolution pixel space y0 = row_old * cell_h_full y1 = (row_old + 1) * cell_h_full x0 = col_old * cell_w_full x1 = (col_old + 1) * cell_w_full # Which new-grid rows & cols overlap? (inclusive ranges) i_start = max(0, min(int(y0 / cell_h_new), new_H - 1)) i_end = max(0, min(int((y1 - 1) / cell_h_new), new_H - 1)) j_start = max(0, min(int(x0 / cell_w_new), new_W - 1)) j_end = max(0, min(int((x1 - 1) / cell_w_new), new_W - 1)) for ii in range(i_start, i_end + 1): for jj in range(j_start, j_end + 1): f_new = ii * new_W + jj if f_new not in seen: seen.add(f_new) new_flat_traj.append(f_new) # 3) annotate counts counts = [] for f in new_flat_traj: i_new, j_new = divmod(f, new_W) counts.append(grid_counts[i_new][j_new]) return new_flat_traj, counts ################################################################################ def make_gif(self, path, n): """ Generate a gif given list of images """ with imageio.get_writer('{}/{}_target_rate_{:.2f}.gif'.format(path, n, self.env.targets_found_rate), mode='I', fps=5) as writer: for frame in self.env.frame_files: image = imageio.imread(frame) writer.append_data(image) print('gif complete\n') # Remove files for filename in self.env.frame_files[:-1]: os.remove(filename) # For KMeans gif if LOAD_AVS_BENCH and USE_CLIP_PREDS: with imageio.get_writer('{}/{}_kmeans_stats.gif'.format(path, n), mode='I', fps=5) as writer: for frame in self.kmeans_clusterer.kmeans_frame_files: image = imageio.imread(frame) writer.append_data(image) print('Kmeans Clusterer gif complete\n') # Remove files for filename in self.kmeans_clusterer.kmeans_frame_files[:-1]: os.remove(filename) # IS gif with imageio.get_writer('{}/{}_IS.gif'.format(path, n), mode='I', fps=5) as writer: for frame in self.IS_frame_files: image = imageio.imread(frame) writer.append_data(image) print('Kmeans Clusterer gif complete\n') # Remove files for filename in self.IS_frame_files[:-1]: os.remove(filename) ################################################################################ class Agent: def __init__(self, ID, infoMap=None, uncertaintyMap=None, shape=None, row=0, col=0, sensorSize=9, numAgents=8): self.ID = ID self.row = row self.col = col self.numAgents = numAgents self.sensorSize = sensorSize def setLocation(self, row, col): self.row = row self.col = col def getLocation(self): return [self.row, self.col] def move(self, action): """ No movement: 0 North (-1,0): 1 East (0,1): 2 South (1,0): 3 West (0,-1): 4 LeftUp (-1,-1) : 5 RightUP (-1,1) :6 RightDown (1,1) :7 RightLeft (1,-1) :8 check valid action of the agent. be sure not to be out of the boundary """ if action == 0: return 0 elif action == 1: self.row -= 1 elif action == 2: self.col += 1 elif action == 3: self.row += 1 elif action == 4: self.col -= 1 elif action == 5: self.row -= 1 self.col -= 1 elif action == 6: self.row -= 1 self.col += 1 elif action == 7: self.row += 1 self.col += 1 elif action == 8: self.row += 1 self.col -= 1 def reverseMove(self, action): if action == 0: return 0 elif action == 1: self.row += 1 elif action == 2: self.col -= 1 elif action == 3: self.row -= 1 elif action == 4: self.col += 1 elif action == 5: self.row += 1 self.col += 1 elif action == 6: self.row += 1 self.col -= 1 elif action == 7: self.row -= 1 self.col -= 1 elif action == 8: self.row -= 1 self.col += 1 else: print("agent can only move NESW/1234") sys.exit() class Target: def __init__(self, row, col, ID, time_found=np.nan): self.row = row self.col = col self.ID = ID self.time_found = time_found self.status = None self.time_visited = time_found def getLocation(self): return self.row, self.col def updateFound(self, timeStep): if np.isnan(self.time_found): self.time_found = timeStep def updateVisited(self, timeStep): if np.isnan(self.time_visited): self.time_visited = timeStep if __name__ == "__main__": search_env = Env(map_index=1, k_size=K_SIZE, n_agent=NUM_ROBOTS, plot=SAVE_GIFS) IS_info_map = search_env.segmentation_info_mask IS_agent_loc = search_env.start_positions IS_target_loc = [[312, 123], [123, 312], [312, 312], [123, 123]] env = ISEnv(state=[IS_info_map, IS_agent_loc, IS_target_loc], shape=(NUM_COORDS_HEIGHT, NUM_COORDS_WIDTH)) env.is_scenario(NUM_EPS_STEPS) print()