Spaces:
Sleeping
Sleeping
| import utils | |
| import os | |
| import math | |
| import json | |
| import jsbeautifier | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import mne | |
| from mne.channels import read_custom_montage | |
| from scipy.interpolate import Rbf | |
| from scipy.optimize import linear_sum_assignment | |
| from sklearn.neighbors import NearestNeighbors | |
def reorder_data(idx_order, fill_flags, inputname, filename):
    """Rearrange the input rows into a 30-row array and save it.

    Row ``i`` of the saved array is built from ``idx_order[i]``:

    * ``fill_flags[i]`` is falsy: copy of the single input row
      ``idx_order[i][0]``;
    * flag set and ``idx_order[i]`` empty: the row stays all zeros;
    * flag set otherwise: element-wise mean of the listed input rows.

    Args:
        idx_order: One list of input-row indices per output row.
        fill_flags: One flag per output row; truthy means the row is
            filled in rather than copied from a single input row.
        inputname: Passed to ``utils.read_train_data`` to load the input.
        filename: Passed to ``utils.save_data`` as the destination.

    Returns:
        The ``shape`` of the array loaded from ``inputname``.
    """
    raw_data = utils.read_train_data(inputname)
    # Every row starts out as zeros, so an empty index set needs no
    # explicit fill step.
    new_data = np.zeros((30, raw_data.shape[1]))
    for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
        if not flag:
            new_data[i, :] = raw_data[idx_set[0], :]
        elif idx_set:
            neighbor_rows = [raw_data[j, :] for j in idx_set]
            new_data[i, :] = np.mean(neighbor_rows, axis=0)
    utils.save_data(new_data, filename)
    return raw_data.shape
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, outputname):
    """Copy processed rows back to the input rows they were taken from.

    For every position ``i`` whose ``fill_flags[i]`` is falsy, row ``i`` of
    the array read from ``filename`` is written to row ``idx_order[i][0]``
    of the output array. Positions whose flag is set are skipped.

    Args:
        batch_cnt: ``0`` starts from a fresh all-zero output array; any
            other value loads the existing array from ``outputname`` and
            updates it.
        raw_data_shape: Shape of the original input; only its first entry
            (the row count) is used, and only when ``batch_cnt == 0``.
        idx_order: One list of input-row indices per processed row.
        fill_flags: One flag per processed row.
        filename: Passed to ``utils.read_train_data`` to load the
            processed array.
        outputname: Passed to ``utils.save_data`` as the destination (and
            to ``utils.read_train_data`` when ``batch_cnt != 0``).
    """
    d_data = utils.read_train_data(filename)
    if batch_cnt == 0:
        new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
    else:
        new_data = utils.read_train_data(outputname)
    for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
        # Rows produced by filling have no single source row to restore.
        if not flag:
            new_data[idx_set[0], :] = d_data[i, :]
    utils.save_data(new_data, outputname)
def get_matched(tpl_order, tpl_dict):
    """Return the channels of ``tpl_order`` whose ``"matched"`` flag is set.

    Args:
        tpl_order: Channel names, in the order the result should follow.
        tpl_dict: Maps each channel name to a dict with a ``"matched"`` entry.

    Returns:
        A new list of the matched channel names, in ``tpl_order`` order.
    """
    return [channel for channel in tpl_order if tpl_dict[channel]["matched"]]
def get_empty_templates(tpl_order, tpl_dict):
    """Return the channels of ``tpl_order`` whose ``"matched"`` flag is unset.

    Args:
        tpl_order: Channel names, in the order the result should follow.
        tpl_dict: Maps each channel name to a dict with a ``"matched"`` entry.

    Returns:
        A new list of the unmatched channel names, in ``tpl_order`` order.
    """
    return [channel for channel in tpl_order if not tpl_dict[channel]["matched"]]
def get_unassigned_inputs(in_order, in_dict):
    """Return the channels of ``in_order`` whose ``"assigned"`` flag is unset.

    Args:
        in_order: Channel names, in the order the result should follow.
        in_dict: Maps each channel name to a dict with an ``"assigned"`` entry.

    Returns:
        A new list of the unassigned channel names, in ``in_order`` order.
    """
    return [channel for channel in in_order if not in_dict[channel]["assigned"]]
def _index_montage(montage, flag_key):
    """Upper-case a montage's channel names in place and index its channels."""
    original_names = montage.ch_names
    upper_names = [str.upper(name) for name in original_names]
    # One batched rename and one position lookup per montage, rather than
    # one of each per channel.
    montage.rename_channels(dict(zip(original_names, upper_names)))
    positions = montage.get_positions()['ch_pos']
    return {
        name: {
            "index": i,
            "coord_3d": positions[name],
            flag_key: False,
        }
        for i, name in enumerate(upper_names)
    }


def read_montage_data(loc_file):
    """Load the template and input montages and index their channels.

    Both montages are read with ``read_custom_montage``; the template always
    comes from ``./template_chanlocs.loc`` (relative to the current working
    directory). Channel names in both montages are converted to upper case.

    Args:
        loc_file: Path of the input channel-location file.

    Returns:
        A tuple ``(tpl_montage, in_montage, tpl_dict, in_dict)``. Each dict
        maps an upper-cased channel name to ``"index"`` (position in the
        montage's channel list), ``"coord_3d"`` (the montage position for
        that channel) and a status flag initialised to ``False``:
        ``"matched"`` for template channels, ``"assigned"`` for input
        channels.
    """
    tpl_montage = read_custom_montage("./template_chanlocs.loc")
    in_montage = read_custom_montage(loc_file)
    tpl_dict = _index_montage(tpl_montage, "matched")
    in_dict = _index_montage(in_montage, "assigned")
    return tpl_montage, in_montage, tpl_dict, in_dict
def save_figures(channel_info, tpl_montage, filename1, filename2):
    """Draw the input channels on the template head and record CSS positions.

    Two images are written from the same figure: ``filename1`` shows every
    input channel in black; ``filename2`` is saved after the channels whose
    ``"assigned"`` flag is unset have been redrawn in red. Afterwards every
    template and input channel receives a ``"css_position"`` entry, a
    ``[left, bottom]`` pair of percentage strings derived from the channel's
    position in display pixels.

    Args:
        channel_info: Dict with ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``; every channel entry must
            already carry a ``"coord_2d"`` value.
        tpl_montage: Template montage; only its ``plot()`` output is used,
            to obtain the head outline.
        filename1: Destination of the plain input-montage image.
        filename2: Destination of the image with unassigned channels in red.

    Returns:
        ``channel_info`` with the updated template and input dicts.

    Note:
        All open matplotlib figures are closed (``plt.close('all')``).
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    def _coords_2d(order, info):
        # (n_channels, 2) array of the stored 2D coordinates.
        xs = [info[channel]["coord_2d"][0] for channel in order]
        ys = [info[channel]["coord_2d"][1] for channel in order]
        return np.vstack((xs, ys)).T

    def _css_position(pixel_xy):
        # The figure is 6.4 in at 100 dpi (640 px), so px / 6.4 is a
        # percentage of the figure size.
        # NOTE(review): the 11 px / 7 px offsets are presumably tuned to the
        # marker/label rendering of the consuming page -- confirm.
        css_left = (pixel_xy[0] - 11) / 6.4
        css_bottom = (pixel_xy[1] - 7) / 6.4
        return [str(round(css_left, 2)) + "%", str(round(css_bottom, 2)) + "%"]

    tpl_coords = _coords_2d(tpl_order, tpl_dict)
    in_coords = _coords_2d(in_order, in_dict)

    # Extract the template's head outline. show=False keeps an interactive
    # backend from opening a window for this throw-away figure.
    tpl_fig = tpl_montage.plot(show=False)
    head_lines = [line.get_data() for line in tpl_fig.axes[0].lines]

    # -------------------------plot input montage------------------------------
    fig = plt.figure(figsize=(6.4, 6.4), dpi=100)
    ax = fig.add_subplot(111)
    fig.tight_layout()
    ax.set_aspect('equal')
    ax.axis('off')
    # plot template's head
    for x, y in head_lines:
        ax.plot(x, y, color='black', linewidth=1.0)
    # plot in_channels on it
    ax.scatter(in_coords[:, 0], in_coords[:, 1], s=35, color='black')
    for i, channel in enumerate(in_order):
        ax.text(in_coords[i, 0] + 0.003, in_coords[i, 1], channel,
                color='black', fontsize=10.0, va='center')
    # save input_montage
    fig.savefig(filename1)

    # ---------------------------add indications-------------------------------
    # redraw the unassigned input channels in red on top of the black ones
    indices = [in_dict[channel]["index"] for channel in in_order
               if not in_dict[channel]["assigned"]]
    if indices:
        ax.scatter(in_coords[indices, 0], in_coords[indices, 1], s=35, color='red')
        for i in indices:
            ax.text(in_coords[i, 0] + 0.003, in_coords[i, 1], in_order[i],
                    color='red', fontsize=10.0, va='center')
    # save mapped_montage
    fig.savefig(filename2)

    # -------------------------------------------------------------------------
    # Convert data coordinates to display pixels. This is done after the
    # figure has been saved, and before it is closed.
    tpl_coords = ax.transData.transform(tpl_coords)
    in_coords = ax.transData.transform(in_coords)
    plt.close('all')

    for i, channel in enumerate(tpl_order):
        tpl_dict[channel]["css_position"] = _css_position(tpl_coords[i])
    for i, channel in enumerate(in_order):
        in_dict[channel]["css_position"] = _css_position(in_coords[i])

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict
    })
    return channel_info
def _thin_plate_warp(src, dst, points):
    """Fit a thin-plate-spline map from ``src`` to ``dst`` and apply it to ``points``."""
    # One scalar interpolator per output dimension, each evaluated at every
    # point; the columns of ``src``/``points`` are the input coordinates.
    warped = [
        Rbf(*src.T, dst[:, dim], function='thin_plate')(*points.T)
        for dim in range(dst.shape[1])
    ]
    return np.vstack(warped).T


def align_coords(channel_info, tpl_montage, in_montage):
    """Warp the input channel positions onto the template's coordinates.

    The channels returned by ``get_matched`` serve as control points: a
    thin-plate-spline mapping is fitted from their input positions to their
    template positions and then applied to every input channel. This is done
    once in 2D (using the scatter positions drawn by ``montage.plot()``) and
    once in 3D (using the stored ``"coord_3d"`` values).

    Args:
        channel_info: Dict with ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``. Every matched template
            channel name must also be a key of the input dict.
        tpl_montage: Template montage (plotted to obtain 2D positions).
        in_montage: Input montage (plotted to obtain 2D positions).

    Returns:
        ``channel_info`` where each template channel has gained
        ``"coord_2d"`` (its plotted position, a NumPy row), and each input
        channel has gained ``"coord_2d"`` and an overwritten ``"coord_3d"``
        (both warped, stored as plain lists).

    Note:
        All open matplotlib figures are closed (``plt.close('all')``).
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    matched = get_matched(tpl_order, tpl_dict)
    tpl_rows = [tpl_dict[channel]["index"] for channel in matched]
    in_rows = [in_dict[channel]["index"] for channel in matched]

    # 2D alignment (for visualization purposes).
    # show=False: the figures are only needed for their scatter offsets.
    tpl_fig = tpl_montage.plot(show=False)
    in_fig = in_montage.plot(show=False)
    # np.ma.getdata works whether get_offsets() returns a masked array or a
    # plain ndarray (on a plain ndarray, ``.data`` is a raw buffer instead).
    all_tpl = np.ma.getdata(tpl_fig.axes[0].collections[0].get_offsets())
    all_in = np.ma.getdata(in_fig.axes[0].collections[0].get_offsets())
    plt.close('all')

    matched_tpl = np.array([all_tpl[row] for row in tpl_rows])
    matched_in = np.array([all_in[row] for row in in_rows])
    transformed_in = _thin_plate_warp(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(tpl_order):
        tpl_dict[channel]["coord_2d"] = all_tpl[i]
    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_2d"] = transformed_in[i].tolist()

    # 3D alignment.
    # Built with np.array so that both ndarray and list coordinates are
    # accepted; this function itself stores the warped input coordinates as
    # lists, so a ``.tolist()`` call here would fail on a second run.
    all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order], dtype=float)
    all_in = np.array([in_dict[channel]["coord_3d"] for channel in in_order], dtype=float)
    matched_tpl = np.array([all_tpl[row] for row in tpl_rows])
    matched_in = np.array([all_in[row] for row in in_rows])
    transformed_in = _thin_plate_warp(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_3d"] = transformed_in[i].tolist()

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict
    })
    return channel_info
def find_neighbors(channel_info, missing_channels, new_idx):
    """Record, for each missing template channel, its nearest input channels.

    A nearest-neighbour model is fitted on the 3D coordinates of every input
    channel. For each channel in ``missing_channels`` the indices of the
    closest input channels (up to four, fewer if there are fewer inputs) are
    stored in ``new_idx`` at the template channel's ``"index"``.

    Args:
        channel_info: Dict with ``"inputOrder"``, ``"templateDict"`` and
            ``"inputDict"``.
        missing_channels: Template channel names to find neighbours for.
        new_idx: List indexed by template channel index; modified in place.

    Returns:
        The same ``new_idx`` list that was passed in.
    """
    input_names = channel_info["inputOrder"]
    template_info = channel_info["templateDict"]
    input_info = channel_info["inputDict"]

    input_points = [np.array(input_info[name]["coord_3d"]) for name in input_names]
    query_points = [np.array(template_info[name]["coord_3d"]) for name in missing_channels]

    # never ask for more neighbours than there are input channels
    neighbor_count = min(4, len(input_names))
    knn = NearestNeighbors(n_neighbors=neighbor_count, metric='euclidean')
    knn.fit(input_points)

    for name, point in zip(missing_channels, query_points):
        _, nearest = knn.kneighbors(point.reshape(1, -1))
        slot = template_info[name]["index"]
        new_idx[slot] = nearest[0].tolist()
    return new_idx
def match_names(stage1_info):
    """Match template channels to input channels that share the same name.

    The montages are loaded with ``read_montage_data``. A template channel
    named T3/T4/T5/T6 is renamed to T7/T8/P7/P8 when the input contains that
    alternative name. Every template channel whose (possibly renamed) name
    also exists in the input is then paired with that input channel.

    Args:
        stage1_info: Dict providing ``["fileNames"]["inputLocation"]``, the
            path of the input channel-location file.

    Returns:
        A tuple ``(stage1_info, channel_info, tpl_montage, in_montage)``.
        ``stage1_info`` gains ``"unassignedInputs"``, ``"missingTemplates"``
        and a one-element ``"mappingResults"`` list holding ``"newOrder"``
        (per template slot: ``[input index]`` or ``[]``) and ``"fillFlags"``
        (``False`` for matched slots, ``True`` otherwise). ``channel_info``
        bundles the channel orders and per-channel dicts.
    """
    # read the location file
    loc_file = stage1_info["fileNames"]["inputLocation"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
    tpl_order = tpl_montage.ch_names
    in_order = in_montage.ch_names

    # Indices of the in_channels in the order of tpl_channels. Built with a
    # comprehension so each slot owns its list ([[]] * 30 would share one).
    new_idx = [[] for _ in range(30)]
    # Whether each tpl_channel's data still has to be filled in.
    fill_flags = [True] * 30

    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8'
    }

    # Iterate over a snapshot: channels are renamed on the montage while
    # looping, so the live ch_names list must not drive the iteration.
    for i, channel in enumerate(list(tpl_order)):
        alias = alias_dict.get(channel)
        if alias is not None and alias in in_dict:
            tpl_montage.rename_channels({channel: alias})
            tpl_dict[alias] = tpl_dict.pop(channel)
            channel = alias
        if channel in in_dict:
            new_idx[i] = [in_dict[channel]["index"]]
            fill_flags[i] = False
            tpl_dict[channel]["matched"] = True
            in_dict[channel]["assigned"] = True

    # pick up the names changed by the alias renames
    tpl_order = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInputs": get_unassigned_inputs(in_order, in_dict),
        "missingTemplates": get_empty_templates(tpl_order, tpl_dict),
        "mappingResults": [
            {
                "newOrder": new_idx,
                "fillFlags": fill_flags
            }
        ]
    })
    channel_info = {
        "templateOrder": tpl_order,
        "inputOrder": in_order,
        "templateDict": tpl_dict,
        "inputDict": in_dict
    }
    return stage1_info, channel_info, tpl_montage, in_montage
def optimal_mapping(channel_info):
    """Assign still-unassigned input channels to the 30 template slots.

    All template ``"matched"`` flags are reset first. A cost matrix of
    Euclidean distances (scaled by 1000) between template and unassigned
    input 3D coordinates is solved with ``linear_sum_assignment``, giving
    each template slot at most one input channel. Template slots left without
    a partner are passed to ``find_neighbors`` to be filled.

    Args:
        channel_info: Dict with ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``. The per-channel
            ``"matched"`` / ``"assigned"`` flags are updated in place.

    Returns:
        A tuple ``(result, channel_info)`` where ``result`` holds
        ``"newOrder"`` (one list of input indices per template slot) and
        ``"fillFlags"`` (``False`` for slots that received an assignment,
        ``True`` otherwise).
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    unassigned = get_unassigned_inputs(in_order, in_dict)

    # reset all tpl.matched to False
    for channel in tpl_dict:
        tpl_dict[channel]["matched"] = False

    all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
    unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])

    # initialize the cost matrix for the Hungarian algorithm
    if len(unassigned) < 30:
        # Pad with high-cost dummy columns so num_col >= num_row.
        cost_matrix = np.full((30, 30), 1e6)
    else:
        cost_matrix = np.zeros((30, len(unassigned)))
    # fill in the distances between tpl and unassigned in_channels
    for i in range(30):
        for j in range(len(unassigned)):
            cost_matrix[i][j] = np.linalg.norm((all_tpl[i] - unassigned_in[j]) * 1000)

    # Optimally assign one in_channel to each tpl_channel by minimizing the
    # total distance between their positions.
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # Store the mapping results. Each slot gets its own list ([[]] * 30 would
    # share a single list between all slots).
    new_idx = [[] for _ in range(30)]
    fill_flags = [True] * 30
    for i, j in zip(row_idx, col_idx):
        if j < len(unassigned):  # filter out dummy channels
            tpl_channel = tpl_order[i]
            in_channel = unassigned[j]
            new_idx[i] = [in_dict[in_channel]["index"]]
            fill_flags[i] = False
            tpl_dict[tpl_channel]["matched"] = True
            in_dict[in_channel]["assigned"] = True

    # fill the remaining empty tpl_channels
    missing_channels = get_empty_templates(tpl_order, tpl_dict)
    if missing_channels:
        new_idx = find_neighbors(channel_info, missing_channels, new_idx)

    result = {
        "newOrder": new_idx,
        "fillFlags": fill_flags
    }
    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict
    })
    return result, channel_info
def mapping_result(stage1_info, channel_info, filename):
    """Run the remaining mapping batches and write all results to a JSON file.

    The batch count is one (the result already present in ``stage1_info``)
    plus one ``optimal_mapping`` pass per 30 unassigned input channels,
    rounded up. Each pass appends its result to the existing
    ``"mappingResults"`` list.

    Args:
        stage1_info: Dict with ``"unassignedInputs"`` and
            ``"mappingResults"``; updated in place.
        channel_info: Channel bookkeeping passed through ``optimal_mapping``.
        filename: Path of the JSON file to write.

    Returns:
        A tuple ``(stage1_info, channel_info)``; ``stage1_info`` carries the
        final ``"batchNum"`` and ``"mappingResults"``.
    """
    pending_inputs = len(stage1_info["unassignedInputs"])
    batch_num = math.ceil(pending_inputs / 30) + 1

    # map the remaining in_channels, 30 per pass, chosen by proximity
    results = stage1_info["mappingResults"]
    for _ in range(batch_num - 1):
        batch_result, channel_info = optimal_mapping(channel_info)
        results.append(batch_result)

    payload = {
        "batchNum": batch_num,
        "mappingResults": results
    }

    beautifier_opts = jsbeautifier.default_options()
    beautifier_opts.indent_size = 4
    pretty_json = jsbeautifier.beautify(json.dumps(payload), beautifier_opts)
    with open(filename, 'w') as jsonfile:
        jsonfile.write(pretty_json)

    stage1_info["batchNum"] = batch_num
    stage1_info["mappingResults"] = results
    return stage1_info, channel_info