Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import csv | |
| from model import cumbersome_model2 | |
| from model import UNet_family | |
| from model import UNet_attention | |
| from model import tf_model | |
| from model import tf_data | |
| import time | |
| import torch | |
| import os | |
| import random | |
| import shutil | |
| from scipy.signal import decimate, resample_poly, firwin, lfilter | |
# Restrict PyTorch to the first GPU, then pick it when CUDA is usable,
# falling back to the CPU otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def resample(signal, fs, tgt_fs):
    """Resample every channel (row) of ``signal`` from ``fs`` Hz to ``tgt_fs`` Hz.

    Integer ratios behave exactly as before:

    * downsampling by an integer factor ``q`` uses ``decimate(ch, q)``;
    * upsampling by an integer factor ``p`` uses ``resample_poly(ch, p, 1)``.

    Non-integer ratios (e.g. 1000 Hz -> 256 Hz, or 250 Hz -> 256 Hz) are
    resampled with a rational ``resample_poly(ch, up, down)``. Previously the
    factor was truncated with ``int()``, producing data at the wrong rate.

    Args:
        signal: Array-like whose rows are channels.
        fs: Current sample rate in Hz.
        tgt_fs: Desired sample rate in Hz.

    Returns:
        ``np.float64`` array with one resampled row per input channel. When the
        rates are equal, this is the input converted to float64.
    """
    from fractions import Fraction

    if fs == tgt_fs:
        return np.array(signal).astype(np.float64)

    # Reduce tgt_fs / fs to a small exact up/down pair (e.g. 256/1000 -> 32/125).
    ratio = (Fraction(float(tgt_fs)) / Fraction(float(fs))).limit_denominator(10000)
    if ratio == 1:
        # Rates differ only by float noise; there is nothing to resample.
        return np.array(signal).astype(np.float64)
    up, down = ratio.numerator, ratio.denominator

    signal_new = []
    for ch in signal:
        if up == 1:
            # Pure integer downsampling: keep the original anti-aliased decimate.
            signal_new.append(decimate(ch, down))
        else:
            # Integer upsampling (down == 1) or a rational ratio.
            signal_new.append(resample_poly(ch, up, down))
    return np.array(signal_new).astype(np.float64)
def FIR_filter(signal, lowcut, highcut, fs=256.0, numtaps=1000):
    """Band-pass ``signal`` between ``lowcut`` and ``highcut`` Hz with an FIR filter.

    The filter is designed with ``firwin`` (band-pass, ``pass_zero=False``) and
    applied causally with ``lfilter`` along the last axis, so each row of a
    2-D array is filtered independently.

    Note: the filter is linear-phase and applied causally, so the output lags
    the input by ``(numtaps - 1) / 2`` samples (about 1.95 s at the defaults).

    Args:
        signal: Array to filter; the time axis is the last axis.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sample rate of ``signal`` in Hz. The default of 256 matches the
            rate ``preprocessing`` resamples to.
        numtaps: Number of FIR taps.

    Returns:
        The filtered array, with the same shape as ``signal``.
    """
    fir_coeff = firwin(numtaps, [lowcut, highcut], pass_zero=False, fs=fs)
    return lfilter(fir_coeff, 1.0, signal)
def read_train_data(file_name):
    """Load a CSV file of numbers as an ``np.float64`` array, one array row per CSV line."""
    with open(file_name, 'r', newline='') as handle:
        rows = [row for row in csv.reader(handle)]
    return np.asarray(rows).astype(np.float64)
def cut_data(filepath, raw_data):
    """Split ``raw_data`` column-wise into 1024-sample segments written as CSV files.

    Segment ``k`` is written to ``filepath + 'temp2/' + str(k) + '.csv'``. The
    ``temp2/`` folder must already exist.

    NOTE(review): trailing samples that do not fill a whole 1024-sample segment
    are not written, so they are never reconstructed.

    Returns:
        The number of segment files written.
    """
    samples = np.array(raw_data).astype(np.float64)
    seg_len = 1024
    n_segments = len(samples[0]) // seg_len
    out_dir = filepath + 'temp2/'
    for seg in range(n_segments):
        start = seg * seg_len
        with open(out_dir + str(seg) + '.csv', 'w', newline='') as csvfile:
            csv.writer(csvfile).writerows(samples[:, start:start + seg_len])
    return n_segments
def glue_data(file_name, total):
    """Concatenate the model outputs ``output0.csv`` .. ``output{total-1}.csv`` along time.

    At each seam, the last sample of the previous chunk and the first sample
    of the next chunk are both replaced by their mean. The original code
    mistakenly paired the previous chunk's last sample with the next chunk's
    *second* sample.

    Args:
        file_name: Directory prefix (with trailing separator) holding the files.
        total: Number of output files to glue.

    Returns:
        ``np.float64`` array of the glued chunks, or ``0`` when ``total`` is 0.
        The ``0`` return is kept from the original behaviour.
    """
    chunks = []
    for i in range(total):
        path = file_name + 'output{}.csv'.format(str(i))
        with open(path, 'r', newline='') as f:
            chunk = np.array(list(csv.reader(f))).astype(np.float64)
        if chunks:
            prev = chunks[-1]
            smooth = (prev[:, -1] + chunk[:, 0]) / 2
            prev[:, -1] = smooth
            chunk[:, 0] = smooth
        chunks.append(chunk)
    if not chunks:
        return 0
    # One concatenation instead of repeated np.append (which copies every time).
    return np.concatenate(chunks, axis=1)
def save_data(data, filename):
    """Write ``data`` (an iterable of rows) to ``filename`` as CSV, replacing any existing file."""
    with open(filename, 'w', newline='') as out:
        csv.writer(out).writerows(data)
def dataDelete(path):
    """Recursively remove the directory ``path``, silently ignoring any ``OSError``."""
    try:
        shutil.rmtree(path)
    except OSError:
        # Best-effort cleanup: a missing folder or a failed removal is not reported.
        return None
# Networks already built and loaded from their checkpoints, keyed by mode.
# reconstruct() calls decode_data() once per 1024-sample segment; caching
# avoids re-reading the checkpoint from disk for every segment.
_LOADED_MODELS = {}


def _load_model(mode):
    """Return the eval-mode network for ``mode``, loading its checkpoint on first use."""
    model = _LOADED_MODELS.get(mode)
    if model is not None:
        return model
    if mode == "ICUNet":
        model = cumbersome_model2.UNet1(n_channels=30, n_classes=30).to(device)
        strict = False
    elif mode == "ICUNet++":
        model = UNet_family.NestedUNet3(num_classes=30).to(device)
        strict = False
    elif mode == "ICUNet_attn":
        model = UNet_attention.UNetpp3_Transformer(num_classes=30).to(device)
        strict = False
    elif mode == "ART":
        model = tf_model.make_model(30, 30, N=2).to(device)
        strict = True
    else:
        raise ValueError(
            f"Unknown model mode {mode!r}; expected 'ICUNet', 'ICUNet++', "
            "'ICUNet_attn' or 'ART'"
        )
    resume_loc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
    # NOTE: torch.load unpickles the file; only local, trusted checkpoints are loaded here.
    checkpoint = torch.load(resume_loc, map_location=device)
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
    model.eval()
    _LOADED_MODELS[mode] = model
    return model


def decode_data(data, std_num, mode=5):
    """Run one normalized segment through the artifact-removal network ``mode``.

    Args:
        data: 2-D numpy array for one segment. The networks are built for 30
            channels; presumably this is (30, samples). TODO confirm.
        std_num: Unused; kept for interface compatibility.
        mode: One of ``"ICUNet"``, ``"ICUNet++"``, ``"ICUNet_attn"`` or
            ``"ART"``. The default ``5`` is not a valid mode.

    Returns:
        ``np.float64`` array of the network output with a leading batch axis of
        size 1. For ``"ART"``, one zero sample is appended on the last axis.

    Raises:
        ValueError: If ``mode`` is not a known model name. Previously this
            surfaced as an UnboundLocalError.
    """
    model = _load_model(mode)
    with torch.no_grad():
        if mode == "ART":
            src = torch.FloatTensor(data).to(device).unsqueeze(0)
            batch = tf_data.Batch(src, src, 0)  # tgt is the same tensor as src
            out = model.forward(batch.src, batch.src[:, :, 1:], batch.src_mask, batch.trg_mask)
            decode = model.generator(out).permute(0, 2, 1)
            # Pad one zero sample on dim 2, as the original implementation did.
            decode = torch.cat((decode, torch.zeros(1, 30, 1).to(device)), dim=2)
        else:
            x = torch.Tensor(data[np.newaxis, :, :]).to(device)
            if mode == "ICUNet":
                decode = model(x)
            else:
                # ICUNet++ / ICUNet_attn return three outputs; the last one is used.
                _, _, decode = model(x)
    return np.array(decode.cpu()).astype(np.float64)
def reorder_data(raw_data, mapping_result):
    """Map the input channels onto the 30 model channels described by ``mapping_result``.

    Row ``k`` of the result, for each (index list, flag) pair, is:

    * the single referenced input row, if ``isOriginalData`` is true;
    * all zeros, if the first index is ``None``;
    * otherwise the mean of the referenced input rows.

    Rows with no mapping entry stay zero.
    """
    n_samples = raw_data.shape[1]
    new_data = np.zeros((30, n_samples))
    pairs = zip(mapping_result["index"], mapping_result["isOriginalData"])
    for row, (indices, is_original) in enumerate(pairs):
        if is_original == True:  # noqa: E712 - equality test kept on purpose
            new_data[row, :] = raw_data[indices[0], :]
        elif indices[0] is None:
            continue  # no source channel: the row is already zero
        else:
            new_data[row, :] = np.mean([raw_data[k, :] for k in indices], axis=0)
    return new_data
def preprocessing(filepath, inputfile, samplerate, mapping_result):
    """Prepare ``inputfile`` for the model and split it into segment files.

    Steps:

    1. (Re)create an empty ``filepath + "temp2/"`` scratch folder.
    2. Read the CSV and map its channels onto the 30 model channels.
    3. Resample from ``samplerate`` to 256 Hz and band-pass to 1-50 Hz.
    4. Write 1024-sample segments into the scratch folder.

    Returns:
        The number of segment files written.
    """
    tmp_dir = filepath + "temp2/"
    try:
        os.mkdir(tmp_dir)
    except OSError as err:
        # Folder left over from an earlier run: wipe it and start fresh.
        dataDelete(tmp_dir)
        os.mkdir(tmp_dir)
        print(err)
    eeg = read_train_data(inputfile)
    eeg = reorder_data(eeg, mapping_result)
    eeg = resample(eeg, samplerate, 256)
    eeg = FIR_filter(eeg, 1, 50)
    return cut_data(filepath, eeg)
def restore_order(data, all_data, mapping_result):
    """Copy model rows back into ``all_data`` at their original channel positions.

    Only entries flagged ``isOriginalData`` are copied. Row ``k`` of ``data``
    goes to row ``index[k][0]`` of ``all_data``. ``all_data`` is modified in
    place and also returned.
    """
    pairs = zip(mapping_result["index"], mapping_result["isOriginalData"])
    for row, (indices, is_original) in enumerate(pairs):
        if not is_original == True:  # noqa: E712 - equality test kept on purpose
            continue
        all_data[indices[0], :] = data[row, :]
    return all_data
def postprocessing(data, samplerate, outputfile, mapping_result, batch_cnt, channel_num):
    """Resample model output back to ``samplerate`` and merge it into ``outputfile``.

    The first batch (``batch_cnt == 0``) starts from a zero array with
    ``channel_num`` rows. Later batches start from the rows already saved in
    ``outputfile``. The merged array is written back to ``outputfile``.
    """
    restored = resample(data, 256, samplerate)
    if batch_cnt == 0:
        all_data = np.zeros((channel_num, restored.shape[1]))
    else:
        all_data = read_train_data(outputfile)
    all_data = restore_order(restored, all_data, mapping_result)
    save_data(all_data, outputfile)
| # model = tf.keras.models.load_model('./denoise_model/') | |
def reconstruct(model_name, total, filepath, batch_cnt):
    """Denoise each segment in ``filepath + 'temp2/'`` with ``model_name`` and glue the results.

    For each segment ``i``:

    1. Read ``temp2/{i}.csv``.
    2. Z-normalize it using the mean and std of the whole segment.
    3. Run it through ``decode_data``.
    4. Write the result to ``temp2/output{i}.csv``.

    The outputs are then glued together, the scratch folder is deleted, and the
    elapsed time is printed.

    Returns:
        The glued output array as produced by ``glue_data``.
    """
    tmp_dir = filepath + 'temp2/'
    start = time.time()
    for i in range(total):
        data_noise = read_train_data(tmp_dir + '{}.csv'.format(str(i)))
        std = np.std(data_noise)
        avg = np.average(data_noise)
        if std == 0:
            # A constant segment (e.g. all channels zero-filled) would divide by
            # zero and feed NaNs to the model; leave it merely mean-centred.
            std = 1.0
        data_noise = (data_noise - avg) / std
        # Deep-learning artifact removal; drop the leading batch axis.
        # NOTE(review): avg/std are not re-applied to the output, so results are
        # saved in normalized units. Confirm this is intended.
        d_data = decode_data(data_noise, std, model_name)[0]
        save_data(d_data, tmp_dir + 'output{}.csv'.format(str(i)))
    data = glue_data(tmp_dir, total)
    dataDelete(tmp_dir)
    elapsed = time.time() - start
    print(f"Using {model_name} model to reconstruct batch-{batch_cnt+1} has been success in {elapsed} sec(s)")
    return data