roseDwayane committed
Commit 788c373
1 Parent(s): 3c7e5a7
Files changed (5)
  1. app.py +394 -0
  2. requirements.txt +19 -0
  3. template_chanlocs.loc.txt +30 -0
  4. template_montage.png +0 -0
  5. utils.py +233 -0
app.py ADDED
@@ -0,0 +1,394 @@
+ import gradio as gr
+ import numpy as np
+ import os
+ import random
+ import utils
+ from channel_mapping import mapping, reorder_data
+
+ import mne
+ from mne.channels import read_custom_montage
+
+ quickstart = """
+ # Quickstart
+
+ ## 1. Channel mapping
+
+ ### Raw data
+ 1. The data needs to be a two-dimensional array (channel, timepoint).
+ 2. Make sure you have **resampled** your data to **256 Hz**.
+ 3. Upload your EEG data in `.csv` format.
+
+ ### Channel locations
+ Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
+ >If you cannot obtain it, we recommend downloading the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to modify them to fit your montage.
+
+ ### Imputation
+ The models were trained using the EEG signals of 30 channels: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
+ We expect your input data to include these channels as well.
+ If your data doesn't contain all of the channels above, there are 3 imputation methods you can choose from:
+
+ <u>Manually</u>:
+ - **mean**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, if you don't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it (depending on the channels you have), we will compute the mean of these 4 channels and assign the result to **FCZ**.
+
+ <u>Automatically</u>:
+ First, we will attempt to find a neighboring channel to use as an alternative. For instance, if the required channel is **FC3** but you only have **FC1**, we will use it as a replacement for **FC3**.
+ Then, depending on the **Imputation** method you chose, we will:
+ - **zero**: fill the missing channels with zeros.
+ - **adjacent**: fill the missing channels using neighboring channels located closer to the center. For example, if the required channel is **FC3** but you only have **F3, C3**, we will choose **C3** as the imputing value for **FC3**.
+ >Note: The imputed channels **need to be removed** after the data has been reconstructed.
+
+ ### Mapping result
+ Once the mapping process is finished, the **template montage** and the **input montage** (with the channels chosen by the mapping function displaying their names) will be shown.
+
+ ### Missing channels
+ The channels displayed here are those for which the template didn't find suitable input channels, so **Imputation** was used to fill in the missing values.
+ Therefore, you need to
+ <span style="color:red">**remove these channels**</span>
+ after you download the denoised data.
+
+ ### Template location file
+ You need to use this as the **new location file** for the denoised data.
+
+ ## 2. Decode data
+
+ ### Model
+ Select the model you want to use.
+ A detailed description of the models can be found on the other pages.
+
+ """
+
+ icunet = """
+ # IC-U-Net
+ ### Abstract
+ Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
+ """
+
+ chkbox_js = """
+ (state_json) => {
+     state_json = JSON.parse(JSON.stringify(state_json));
+     if(state_json.state == "finished") return;
+
+     // overlay the checkbox group on the raw-montage image
+     document.querySelector("#chs-chkbox>div:nth-of-type(2)").style.cssText = `
+         position: relative;
+         width: 560px;
+         height: 560px;
+         background: url("file=${state_json.files.raw_montage}");
+     `;
+
+     let all_chkbox = document.querySelectorAll("#chs-chkbox> div:nth-of-type(2)> label");
+     all_chkbox = Array.apply(null, all_chkbox);
+
+     // place each checkbox at its channel's position on the montage
+     all_chkbox.forEach((item, index) => {
+         let channel = state_json.inputByIndex[index];
+         let left = state_json.inputByName[channel].css_position[0];
+         let bottom = state_json.inputByName[channel].css_position[1];
+
+         item.style.cssText = `
+             position: absolute;
+             left: ${left};
+             bottom: ${bottom};
+         `;
+         item.className = "";
+         item.querySelector("span").innerText = "";
+     });
+ }
+ """
+
+
+ with gr.Blocks() as demo:
+
+     state_json = gr.JSON(elem_id="state", visible=False)
+
+     with gr.Row():
+         gr.Markdown(
+             """
+
+             """
+         )
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown(
+                 """
+                 # 1. Channel Mapping
+                 """
+             )
+             with gr.Row():
+                 in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
+                 in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", ".locs"])
+             with gr.Row():
+                 in_fill_mode = gr.Dropdown(choices=["zero",
+                                                     ("adjacent channel", "adjacent"),
+                                                     ("mean (manually select channels)", "mean")],
+                                            value="zero",
+                                            label="Imputation",
+                                            scale=2)
+                 map_btn = gr.Button("Mapping", scale=1)
+             channels_json = gr.JSON(visible=False)
+             res_md = gr.Markdown(
+                 """
+                 ### Mapping result:
+                 """,
+                 visible=False
+             )
+             with gr.Row():
+                 tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
+                 map_montage = gr.Image(label="Chosen channels", visible=False)
+             chs_chkbox = gr.CheckboxGroup(elem_id="chs-chkbox", label="", visible=False)
+             next_btn = gr.Button("Next", interactive=False, visible=False)
+             miss_txtbox = gr.Textbox(label="Missing channels", visible=False)
+             tpl_loc_file = gr.File("./template_chanlocs.loc", show_label=False, visible=False)
+         with gr.Column():
+             gr.Markdown(
+                 """
+                 # 2. Decode Data
+                 """
+             )
+             with gr.Row():
+                 in_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART", "(mapped data)"],
+                                             value="ICUNet",
+                                             label="Model",
+                                             scale=2)
+                 run_btn = gr.Button(scale=1, interactive=False)
+             out_denoised_data = gr.File(label="Denoised data")
+
+
+     with gr.Row():
+         with gr.Tab("EEGART"):
+             gr.Markdown()
+         with gr.Tab("IC-U-Net"):
+             gr.Markdown(icunet)
+         with gr.Tab("IC-U-Net++"):
+             gr.Markdown()
+         with gr.Tab("IC-U-Net-Att"):
+             gr.Markdown()
+         with gr.Tab("QuickStart"):
+             gr.Markdown(quickstart)
+
+     #demo.load(js=js)
+
+     def reset_layout(raw_data):
+         # establish temp folder
+         filepath = os.path.dirname(str(raw_data))
+         try:
+             os.mkdir(filepath+"/temp_data/")
+         except OSError:
+             utils.dataDelete(filepath+"/temp_data/")
+             os.mkdir(filepath+"/temp_data/")
+         state_obj = {
+             "filepath": filepath+"/temp_data/",
+             "files": {}
+         }
+         return {state_json : state_obj,
+                 chs_chkbox : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
+                 next_btn : gr.Button("Next", interactive=False, visible=False),
+                 run_btn : gr.Button(interactive=False),
+                 tpl_montage : gr.Image(visible=False),
+                 map_montage : gr.Image(value=None, visible=False),
+                 miss_txtbox : gr.Textbox(visible=False),
+                 res_md : gr.Markdown(visible=False),
+                 tpl_loc_file : gr.File(visible=False)}
+
+     def mapping_result(state_obj, channels_obj, raw_data, fill_mode):
+         state_obj.update(channels_obj)
+
+         if fill_mode=="mean" and channels_obj["missingChannelsIndex"]!=[]:
+             state_obj.update({
+                 "state" : "initializing",
+                 "fillingCount" : 0,
+                 "totalFillingNum" : len(channels_obj["missingChannelsIndex"])-1
+             })
+             return {state_json : state_obj,
+                     next_btn : gr.Button(visible=True)}
+         else:
+             reorder_data(raw_data, channels_obj["newOrder"], fill_mode, state_obj)
+
+             missing_channels = [state_obj["templateByIndex"][idx] for idx in state_obj["missingChannelsIndex"]]
+             missing_channels = ', '.join(missing_channels)
+
+             state_obj.update({
+                 "state" : "finished"
+             })
+             return {state_json : state_obj,
+                     res_md : gr.Markdown(visible=True),
+                     miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
+                     tpl_loc_file : gr.File(visible=True),
+                     run_btn : gr.Button(interactive=True)}
+
+     def show_montage(state_obj, raw_loc):
+         filepath = state_obj["filepath"]
+         raw_montage = read_custom_montage(raw_loc)
+
+         # convert all channel names to uppercase
+         for i in range(len(raw_montage.ch_names)):
+             channel = raw_montage.ch_names[i]
+             raw_montage.rename_channels({channel: str.upper(channel)})
+
+         if state_obj["state"] == "initializing":
+             filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
+             state_obj["files"]["raw_montage"] = filename
+             raw_fig = raw_montage.plot()
+             raw_fig.set_size_inches(5.6, 5.6)
+             raw_fig.savefig(filename, pad_inches=0)
+
+             return {state_json : state_obj}
+
+         elif state_obj["state"] == "finished":
+             filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
+             state_obj["files"]["map_montage"] = filename
+
+             show_names = []
+             for channel in state_obj["inputByName"]:
+                 if state_obj["inputByName"][channel]["used"]:
+                     if channel=='CZ' and state_obj["CZImputed"]:
+                         continue
+                     show_names.append(channel)
+             mapped_fig = raw_montage.plot(show_names=show_names)
+             mapped_fig.set_size_inches(5.6, 5.6)
+             mapped_fig.savefig(filename, pad_inches=0)
+
+             return {state_json : state_obj,
+                     tpl_montage : gr.Image(visible=True),
+                     map_montage : gr.Image(value=filename, visible=True)}
+
+         elif state_obj["state"] == "selecting":
+             return {state_json : state_obj}
+
+     def generate_chkbox(state_obj):
+         if state_obj["state"] == "initializing":
+             in_channels = [channel for channel in state_obj["inputByName"]]
+             state_obj["state"] = "selecting"
+
+             first_idx = state_obj["missingChannelsIndex"][0]
+             first_name = state_obj["templateByIndex"][first_idx]
+             chkbox_label = first_name+' (1/'+str(state_obj["totalFillingNum"]+1)+')'
+             return {state_json : state_obj,
+                     chs_chkbox : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
+                     next_btn : gr.Button(interactive=True)}
+         else:
+             return {state_json : state_obj}
+
+
+     map_btn.click(
+         fn = reset_layout,
+         inputs = in_raw_data,
+         outputs = [state_json, chs_chkbox, next_btn, run_btn, tpl_montage, map_montage, miss_txtbox,
+                    res_md, tpl_loc_file]
+     ).success(
+         fn = mapping,
+         inputs = [in_raw_data, in_raw_loc, in_fill_mode],
+         outputs = channels_json
+     ).success(
+         fn = mapping_result,
+         inputs = [state_json, channels_json, in_raw_data, in_fill_mode],
+         outputs = [state_json, chs_chkbox, next_btn, miss_txtbox, res_md, tpl_loc_file, run_btn]
+     ).success(
+         fn = show_montage,
+         inputs = [state_json, in_raw_loc],
+         outputs = [state_json, tpl_montage, map_montage]
+     ).success(
+         fn = generate_chkbox,
+         inputs = state_json,
+         outputs = [state_json, chs_chkbox, next_btn]
+     ).success(
+         fn = None,
+         js = chkbox_js,
+         inputs = state_json,
+         outputs = []
+     )
+
+
+     def check_next(state_obj, selected, raw_data, fill_mode):
+         if state_obj["state"] == "selecting":
+
+             # save info before clicking on next_btn
+             prev_target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
+             prev_target_name = state_obj["templateByIndex"][prev_target_idx]
+
+             selected_idx = [state_obj["inputByName"][channel]["index"] for channel in selected]
+             state_obj["newOrder"][prev_target_idx] = selected_idx
+
+             if len(selected)==1 and state_obj["inputByName"][selected[0]]["used"]==False:
+                 state_obj["inputByName"][selected[0]]["used"] = True
+                 state_obj["missingChannelsIndex"][state_obj["fillingCount"]] = -1
+
+             print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
+
+             # update next round
+             state_obj["fillingCount"] += 1
+             if state_obj["fillingCount"] <= state_obj["totalFillingNum"]:
+                 target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
+                 target_name = state_obj["templateByIndex"][target_idx]
+                 chkbox_label = target_name+' ('+str(state_obj["fillingCount"]+1)+'/'+str(state_obj["totalFillingNum"]+1)+')'
+                 btn_label = "Submit" if state_obj["fillingCount"]==state_obj["totalFillingNum"] else "Next"
+
+                 return {state_json : state_obj,
+                         chs_chkbox : gr.CheckboxGroup(value=[], label=chkbox_label),
+                         next_btn : gr.Button(btn_label)}
+             else:
+                 state_obj["state"] = "finished"
+                 reorder_data(raw_data, state_obj["newOrder"], fill_mode, state_obj)
+
+                 missing_channels = []
+                 for idx in state_obj["missingChannelsIndex"]:
+                     if idx != -1:
+                         missing_channels.append(state_obj["templateByIndex"][idx])
+                 missing_channels = ', '.join(missing_channels)
+
+                 return {state_json : state_obj,
+                         chs_chkbox : gr.CheckboxGroup(visible=False),
+                         next_btn : gr.Button(visible=False),
+                         res_md : gr.Markdown(visible=True),
+                         miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
+                         tpl_loc_file : gr.File(visible=True),
+                         run_btn : gr.Button(interactive=True)}
+
+     next_btn.click(
+         fn = check_next,
+         inputs = [state_json, chs_chkbox, in_raw_data, in_fill_mode],
+         outputs = [state_json, chs_chkbox, next_btn, run_btn, res_md, miss_txtbox, tpl_loc_file]
+     ).success(
+         fn = show_montage,
+         inputs = [state_json, in_raw_loc],
+         outputs = [state_json, tpl_montage, map_montage]
+     )
+
+
+     @run_btn.click(inputs=[state_json, in_raw_data, in_model_name], outputs=out_denoised_data)
+     def run_model(state_obj, raw_file, model_name):
+         filepath = state_obj["filepath"]
+
+         input_name = os.path.basename(str(raw_file))
+         output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
+
+         if model_name == "(mapped data)":
+             return filepath + 'mapped.csv'
+
+         # step 1: data preprocessing
+         total_file_num = utils.preprocessing(filepath, 'mapped.csv', 256)
+
+         # step 2: signal reconstruction
+         utils.reconstruct(model_name, total_file_num, filepath, output_name)
+
+         return filepath + output_name
+
+
+ if __name__ == "__main__":
+     demo.launch()
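The "mean" imputation mode that the Quickstart string above describes boils down to averaging the user-selected channels and assigning the result to the missing one. A minimal numpy sketch (the channel names and values here are hypothetical, not taken from the repo):

```python
import numpy as np

# Hypothetical data for four selected channels, two timepoints each,
# standing in for FC1, FC2, FZ, CZ when FCZ is missing.
selected = np.array([[1.0, 2.0],   # FC1
                     [3.0, 4.0],   # FC2
                     [5.0, 6.0],   # FZ
                     [7.0, 8.0]])  # CZ

# "mean" mode: average across the selected channels, timepoint by timepoint.
imputed_fcz = selected.mean(axis=0)
print(imputed_fcz)  # [4. 5.]
```

As the Quickstart notes, a channel imputed this way should be removed again after the denoised data is downloaded.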
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ certifi==2023.5.7
+ charset-normalizer==3.1.0
+ filelock==3.12.0
+ idna==3.4
+ Jinja2==3.1.2
+ MarkupSafe==2.1.2
+ mpmath==1.3.0
+ networkx==3.1
+ numpy==1.24.3
+ Pillow==9.5.0
+ requests==2.31.0
+ scipy==1.10.1
+ sympy==1.12
+ torch==2.0.1
+ torchaudio==2.0.2
+ torchvision==0.15.2
+ typing_extensions==4.6.2
+ urllib3==2.0.2
+ mne==1.7.0
template_chanlocs.loc.txt ADDED
@@ -0,0 +1,30 @@
+ 1 -19.33 0.52497 FP1
+ 2 19.385 0.52499 FP2
+ 3 -58.847 0.54399 F7
+ 4 -43.411 0.33339 F3
+ 5 0.30571 0.22978 FZ
+ 6 43.668 0.34149 F4
+ 7 58.694 0.54439 F8
+ 8 -80.084 0.54296 FT7
+ 9 -69.321 0.27328 FC3
+ 10 0.7867 0.095376 FCZ
+ 11 69.152 0.27863 FC4
+ 12 79.329 0.54305 FT8
+ 13 -100.78 0.53459 T3
+ 14 -100.09 0.25493 C3
+ 15 177.5 0.029055 CZ
+ 16 99.225 0.26068 C4
+ 17 100.01 0.53482 T4
+ 18 -118.48 0.52323 TP7
+ 19 -126.49 0.27946 CP3
+ 20 179.53 0.14139 CPZ
+ 21 125 0.28397 CP4
+ 22 118.03 0.52338 TP8
+ 23 -135.4 0.50767 T5
+ 24 -146.07 0.33054 P3
+ 25 179.77 0.24709 PZ
+ 26 144.68 0.33093 P4
+ 27 135 0.50782 T6
+ 28 -165.34 0.47584 O1
+ 29 179.95 0.45961 OZ
+ 30 165.1 0.47591 O2
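Each row of the template location file above follows EEGLAB's polar `.loc` convention: channel index, theta (degrees), radius, label. A minimal sketch of parsing a few of those rows by hand (this is an illustration, not the parser the app uses; the app reads such files via `mne.channels.read_custom_montage`):

```python
# Three rows copied from the template file above.
loc_rows = """1 -19.33 0.52497 FP1
2 19.385 0.52499 FP2
15 177.5 0.029055 CZ"""

channels = []
for line in loc_rows.splitlines():
    idx, theta, radius, label = line.split()
    channels.append({"index": int(idx), "theta": float(theta),
                     "radius": float(radius), "label": label})

print([ch["label"] for ch in channels])  # ['FP1', 'FP2', 'CZ']
```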
template_montage.png ADDED
utils.py ADDED
@@ -0,0 +1,233 @@
+ import numpy as np
+ import csv
+ from model import cumbersome_model2
+ from model import UNet_family
+ from model import UNet_attention
+ from model import tf_model
+ from model import tf_data
+
+ import time
+ import torch
+ import os
+ import random
+ import shutil
+ from scipy.signal import decimate, resample_poly, firwin, lfilter
+
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ def resample(signal, fs):
+     # downsample the signal to a sample rate of 256 Hz
+     if fs > 256:
+         fs_down = 256          # desired sample rate
+         q = int(fs / fs_down)  # downsampling factor
+         signal_new = []
+         for ch in signal:
+             x_down = decimate(ch, q)
+             signal_new.append(x_down)
+
+     # upsample the signal to a sample rate of 256 Hz
+     elif fs < 256:
+         fs_up = 256            # desired sample rate
+         p = int(fs_up / fs)    # upsampling factor
+         signal_new = []
+         for ch in signal:
+             x_up = resample_poly(ch, p, 1)
+             signal_new.append(x_up)
+
+     else:
+         signal_new = signal
+
+     signal_new = np.array(signal_new).astype(np.float64)
+
+     return signal_new
+
+ def FIR_filter(signal, lowcut, highcut):
+     fs = 256.0
+     # number of FIR filter taps
+     numtaps = 1000
+     # use firwin to create a bandpass FIR filter
+     fir_coeff = firwin(numtaps, [lowcut, highcut], pass_zero=False, fs=fs)
+     # apply the filter to the signal
+     filtered_signal = lfilter(fir_coeff, 1.0, signal)
+
+     return filtered_signal
+
+
+ def read_train_data(file_name):
+     with open(file_name, 'r', newline='') as f:
+         lines = csv.reader(f)
+         data = []
+         for line in lines:
+             data.append(line)
+
+     data = np.array(data).astype(np.float64)
+     return data
+
+
+ def cut_data(filepath, raw_data):
+     raw_data = np.array(raw_data).astype(np.float64)
+     total = int(len(raw_data[0]) / 1024)
+     for i in range(total):
+         table = raw_data[:, i * 1024:(i + 1) * 1024]
+         filename = filepath + '/temp2/' + str(i) + '.csv'
+         with open(filename, 'w', newline='') as csvfile:
+             writer = csv.writer(csvfile)
+             writer.writerows(table)
+     return total
+
+
+ def glue_data(file_name, total, output):
+     gluedata = 0
+     for i in range(total):
+         file_name1 = file_name + 'output{}.csv'.format(str(i))
+         with open(file_name1, 'r', newline='') as f:
+             lines = csv.reader(f)
+             raw_data = []
+             for line in lines:
+                 raw_data.append(line)
+         raw_data = np.array(raw_data).astype(np.float64)
+         if i == 0:
+             gluedata = raw_data
+         else:
+             # smooth the boundary between consecutive segments
+             smooth = (gluedata[:, -1] + raw_data[:, 1]) / 2
+             gluedata[:, -1] = smooth
+             raw_data[:, 1] = smooth
+             gluedata = np.append(gluedata, raw_data, axis=1)
+     filename2 = output
+     with open(filename2, 'w', newline='') as csvfile:
+         writer = csv.writer(csvfile)
+         writer.writerows(gluedata)
+
+
+ def save_data(data, filename):
+     with open(filename, 'w', newline='') as csvfile:
+         writer = csv.writer(csvfile)
+         writer.writerows(data)
+
+ def dataDelete(path):
+     try:
+         shutil.rmtree(path)
+     except OSError as e:
+         print(e)
+
+
+
+ def decode_data(data, std_num, mode=5):
+
+     if mode == "ICUNet":
+         # 1. read name
+         model = cumbersome_model2.UNet1(n_channels=30, n_classes=30)
+         resumeLoc = './model/ICUNet/modelsave/checkpoint.pth.tar'
+         # 2. load model
+         checkpoint = torch.load(resumeLoc, map_location='cpu')
+         model.load_state_dict(checkpoint['state_dict'], False)
+         model.eval()
+         # 3. decode strategy
+         with torch.no_grad():
+             data = data[np.newaxis, :, :]
+             data = torch.Tensor(data)
+             decode = model(data)
+
+     elif mode == "UNetpp" or mode == "AttUnet":
+         # 1. read name
+         if mode == "UNetpp":
+             model = UNet_family.NestedUNet3(num_classes=30)
+         elif mode == "AttUnet":
+             model = UNet_attention.UNetpp3_Transformer(num_classes=30)
+         resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
+         # 2. load model
+         checkpoint = torch.load(resumeLoc, map_location='cpu')
+         model.load_state_dict(checkpoint['state_dict'], False)
+         model.eval()
+         # 3. decode strategy
+         with torch.no_grad():
+             data = data[np.newaxis, :, :]
+             data = torch.Tensor(data)
+             decode1, decode2, decode = model(data)
+
+     elif mode == "EEGART":
+         # 1. read name
+         resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
+         # 2. load model
+         checkpoint = torch.load(resumeLoc, map_location='cpu')
+         model = tf_model.make_model(30, 30, N=2)
+         model.load_state_dict(checkpoint['state_dict'])
+         model.eval()
+         # 3. decode strategy
+         with torch.no_grad():
+             data = torch.FloatTensor(data)
+             data = data.unsqueeze(0)
+             src = data
+             tgt = data
+             batch = tf_data.Batch(src, tgt, 0)
+             out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
+             decode = model.generator(out)
+             decode = decode.permute(0, 2, 1)
+
+     # 4. numpy
+     decode = np.array(decode.cpu()).astype(np.float64)
+     return decode
+
+ def preprocessing(filepath, filename, samplerate):
+     # establish temp folder
+     try:
+         os.mkdir(filepath+"/temp2/")
+     except OSError:
+         dataDelete(filepath+"/temp2/")
+         os.mkdir(filepath+"/tem2/".replace("tem2", "temp2"))
+
+     # read data
+     signal = read_train_data(filepath+'/'+filename)
+     # resample
+     signal = resample(signal, samplerate)
+     # FIR filter
+     signal = FIR_filter(signal, 1, 50)
+     # cut the data into 1024-sample segments
+     total_file_num = cut_data(filepath, signal)
+
+     return total_file_num
+
+
+ def reconstruct(model_name, total, filepath, outputfile):
+     # ------------------- decode_data ---------------------------
+     second1 = time.time()
+     for i in range(total):
+         file_name = filepath + '/temp2/{}.csv'.format(str(i))
+         data_noise = read_train_data(file_name)
+
+         std = np.std(data_noise)
+         avg = np.average(data_noise)
+
+         data_noise = (data_noise-avg)/std
+
+         # deep-learning artifact removal
+         d_data = decode_data(data_noise, std, model_name)
+         d_data = d_data[0]
+
+         outputname = filepath + '/temp2/output{}.csv'.format(str(i))
+         save_data(d_data, outputname)
+
+     # -------------------- glue_data ----------------------------
+     glue_data(filepath+"/temp2/", total, filepath+'/'+outputfile)
+     # ------------------- delete_data ---------------------------
+     dataDelete(filepath+"/temp2/")
+     second2 = time.time()
+
+     print("Using the", model_name, "model to reconstruct", outputfile, "succeeded in", second2 - second1, "sec(s)")
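The boundary handling in `glue_data` above can be seen in isolation: when appending a newly reconstructed 1024-sample segment, the last column of the running result and a near-boundary column of the incoming segment are averaged so the seam between segments is smoothed. A small numpy sketch with toy 2-channel data (the values are made up for illustration):

```python
import numpy as np

# Running result so far (2 channels x 2 timepoints) and the next segment.
gluedata = np.array([[0.0, 2.0],
                     [0.0, 4.0]])
raw_data = np.array([[1.0, 6.0, 7.0],
                     [1.0, 8.0, 9.0]])

# Average the boundary columns, write the average back to both sides,
# then concatenate along the time axis (mirroring glue_data's logic).
smooth = (gluedata[:, -1] + raw_data[:, 1]) / 2
gluedata[:, -1] = smooth
raw_data[:, 1] = smooth
gluedata = np.append(gluedata, raw_data, axis=1)
print(gluedata.shape)  # (2, 5)
```

Note that the code smooths against column index 1 of the incoming segment rather than column 0, so the very first sample of each new segment is left untouched.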