博客
关于我
强烈建议你试试无所不能的 ChatGPT,快点击我
RNN
阅读量:5313 次
发布时间:2019-06-14

本文共 3040 字,大约阅读时间需要 10 分钟。

import numpy as np
import os
import pickle
import random
import datetime
from keras.models import Sequential
from keras.layers import Dense, Activation, LSTM

FOLDERS = [
    {"class": 1, "folder": "/data1/linzn/data/ch_g729a_100_10000ms_FEAT"},  # The folder that contains positive data files.
    {"class": 0, "folder": "/data1/linzn/data/ch_g729a_0_10000ms_FEAT"},  # The folder that contains negative data files.
]

SAMPLE_LENGTH = 10000  # The sample length (ms)
BATCH_SIZE = 32  # batch size
ITER = 30  # number of iteration
FOLD = 5  # = NUM_SAMPLE / number of testing samples


def get_file_list(folder):
    """Return the paths of all regular files directly inside *folder*.

    Sub-directories are skipped so that every returned path can be handed
    to ``parse_sample``. Paths are sorted so the listing does not depend on
    the arbitrary order of ``os.listdir``.
    """
    return [
        path
        for path in (os.path.join(folder, name) for name in sorted(os.listdir(folder)))
        if os.path.isfile(path)
    ]


def parse_sample(file_path):
    """Read one codeword file.

    Parameters
    ----------
    file_path : str
        Path to an ASCII file. Each non-blank line contains three integers
        ``x1 x2 x3`` (tab- or space-separated), the three codewords of one
        frame. There is one line per frame.

    Returns
    -------
    list[list[int]]
        One list of integer codewords per frame, in file order.
    """
    sample = []
    with open(file_path, "r") as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing newline); otherwise a ['']
                # row would make the samples ragged and break np.asarray.
                continue
            sample.append([int(field) for field in fields])
    return sample


def save_variable(file_name, variable):
    """Serialize *variable* to *file_name* with pickle (binary mode)."""
    with open(file_name, "wb") as file_object:
        pickle.dump(variable, file_object)


# Pruned RNN-SM training and testing
if __name__ == '__main__':
    # (path, label) pairs for every file in every class folder, in random order.
    all_files = [(item, folder["class"]) for folder in FOLDERS for item in get_file_list(folder["folder"])]
    random.shuffle(all_files)
    save_variable('all_files.pkl', all_files)

    all_samples_x = [parse_sample(item[0]) for item in all_files]
    all_samples_y = [item[1] for item in all_files]
    np_all_samples_x = np.asarray(all_samples_x)
    np_all_samples_y = np.asarray(all_samples_y)
    save_variable('np_all_samples_x.pkl', np_all_samples_x)
    save_variable('np_all_samples_y.pkl', np_all_samples_y)

    # The first 1/FOLD of the shuffled samples is held out for testing.
    file_num = len(all_files)
    sub_file_num = file_num // FOLD
    if sub_file_num == 0:
        raise ValueError(
            "need at least %d sample files to build a test split, found %d" % (FOLD, file_num)
        )
    x_test = np_all_samples_x[:sub_file_num]  # The samples for testing
    y_test = np_all_samples_y[:sub_file_num]  # The label of the samples for testing
    x_train = np_all_samples_x[sub_file_num:]  # The samples for training
    y_train = np_all_samples_y[sub_file_num:]  # The label of the samples for training

    print("Building model")
    model = Sequential()
    # One timestep per line of a sample file, 3 codewords per timestep.
    # SAMPLE_LENGTH / 10 timesteps presumably means 10 ms frames -- confirm
    # against the feature extractor that produced the files.
    model.add(LSTM(50, input_shape=(SAMPLE_LENGTH // 10, 3), return_sequences=True))  # first layer
    model.add(LSTM(50))  # second layer
    model.add(Dense(1))  # output layer
    model.add(Activation('sigmoid'))  # activation function
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=["accuracy"])

    print("Training")
    # Train one epoch at a time so a checkpoint is written after every epoch.
    for i in range(ITER):
        model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=1, validation_data=(x_test, y_test))
        model.save('model_%d.h5' % (i + 1))

  

转载于:https://www.cnblogs.com/wangzc521/p/10606679.html

你可能感兴趣的文章
Day2----hiddenMeau
查看>>
保留你的dSYM文件
查看>>
将iPhone5s中的相片批量下载到电脑中
查看>>
union和union all
查看>>
3.5 [ Enterprise Library ]注入模型设计
查看>>
网易首页导航封装类优化
查看>>
[转]Java连接各种数据库的方法
查看>>
项目经理如何把工作简单化
查看>>
【笔记】html的改变(上)
查看>>
String(三)
查看>>
ABAP术语-Application Server
查看>>
Spring的IOC原理
查看>>
Ubuntu学习
查看>>
见鬼吧,拉格朗日插值法
查看>>
HBase简介及原理
查看>>
团队-象棋游戏-代码设计规范
查看>>
javascript中对象的深复制的几种方法
查看>>
缓冲区溢出攻击
查看>>
【转】RESTful Web Services初探
查看>>
分享一下我习惯用的快捷键
查看>>