🔥码云GVP开源项目 12k star Uniapp+ElementUI 功能强大 支持多语言、二开方便! 广告
# 神经网络交易的训练部分

> 来源:https://uqer.io/community/share/55b8af12f9f06c91f818c607

```py
import pybrain as brain

training_set = ("20050101", "20130101")  # training window (2005-2013, eight years)
testing_set = ("20150101", "20150525")   # test window (2015-01-01 to 2015-05-25)
universe = ['000001']                    # tickers in the stock universe
HISTORY = 10                             # number of past values fed to the network per sample
```

```py
from pybrain.datasets import SupervisedDataSet


def _make_dataset(begin_date, end_date):
    """Build a SupervisedDataSet from the daily close prices of every ticker in ``universe``.

    For each ticker the close-price series p is turned into the ratio series
    r[k] = p[k] / p[k + 1] - 1.  Every window of HISTORY consecutive ratios
    becomes one input sample, and the ratio immediately after the window is
    its target.

    :param begin_date: first trade date to fetch, as a 'YYYYMMDD' string.
    :param end_date: last trade date to fetch, as a 'YYYYMMDD' string.
    :returns: a SupervisedDataSet with HISTORY inputs and 1 target.
    """
    ds = SupervisedDataSet(HISTORY, 1)
    for ticker in universe:
        raw_data = DataAPI.MktEqudGet(ticker=ticker,
                                      beginDate=begin_date,
                                      endDate=end_date,
                                      field=['tradeDate', 'closePrice'],  # only the columns we need
                                      pandas="1")
        plist = list(raw_data['closePrice'])
        # NOTE(review): assuming rows come back in ascending tradeDate order,
        # this is previous/current - 1 rather than the usual
        # current/previous - 1 daily return. Confirm the direction is intended.
        ratios = [plist[k] / plist[k + 1] - 1 for k in range(len(plist) - 1)]
        # Use every window whose target still exists, including the last
        # one (the previous loop bound stopped one window early).
        for start in range(len(ratios) - HISTORY):
            ds.addSample(ratios[start:start + HISTORY], ratios[start + HISTORY])
    return ds


def make_training_data():
    """Build the training dataset from the ``training_set`` date range."""
    return _make_dataset(training_set[0], training_set[1])


def make_testing_data():
    """Build the test dataset from the ``testing_set`` date range."""
    return _make_dataset(testing_set[0], testing_set[1])
```

```py
from pybrain.supervised.trainers import BackpropTrainer


def make_trainer(net, ds, momentum=0.1, verbose=True, weightdecay=0.01):
    """Create a back-propagation trainer for ``net`` on dataset ``ds``.

    :param net: the network to train.
    :param ds: the training dataset.
    :param momentum: momentum passed through to BackpropTrainer.
    :param verbose: passed through to BackpropTrainer; when True the trainer
        prints the total error after each epoch (see the output below).
    :param weightdecay: weight decay passed through to BackpropTrainer.
    :returns: the configured BackpropTrainer.
    """
    return BackpropTrainer(net, ds, momentum=momentum, verbose=verbose,
                           weightdecay=weightdecay)
```

```py
def start_training(trainer, epochs=15):
    """Run ``epochs`` training epochs on ``trainer``."""
    trainer.trainEpochs(epochs)


def start_testing(net, dataset):
    """Return the network's outputs for every input sample in ``dataset``."""
    return net.activateOnDataset(dataset)
```

```py
from pybrain.tools.customxml import NetworkWriter


def save_arguments(net, filename='huge_data.csv'):
    """Serialize ``net`` (structure and weights) to ``filename`` via PyBrain's NetworkWriter.

    NetworkWriter comes from ``pybrain.tools.customxml``, so the file content
    is XML despite the historical ``.csv`` default name. That default is kept
    so anything that already reads ``huge_data.csv`` still finds the file.
    """
    NetworkWriter.writeToFile(net, filename)
    # Report the path that was actually written (the old message named a
    # different file, net.csv).
    print('Arguments saved to file %s' % filename)
```

```py
from pybrain.tools.shortcuts import buildNetwork

# Network layout: HISTORY inputs, hidden layers of 15 and 7 units, 1 output.
fnn = buildNetwork(HISTORY, 15, 7, 1)
training_dataset = make_training_data()
testing_dataset = make_testing_data()
trainer = make_trainer(fnn, training_dataset)
start_training(trainer, 5)
save_arguments(fnn)
print(start_testing(fnn, testing_dataset))
```

Recorded output from the original version of the code (so the last line still shows its old message):

```
Total error: 0.00226884924246
Total error: 0.00058242191557
Total error: 0.00058089738079
Total error: 0.000581061747831
Total error: 0.000580708420341
Arguments save to file net.csv
```