一、图表
1.图表类型
import matplotlib.pyplot as plt

# Reference list of the basic pyplot chart functions.
# NOTE: the calls are shown without arguments, exactly as in the notes; apart
# from plot() and show() they need data arguments before they can be run.
plt.hist()     # frequency histogram
plt.plot()     # line chart; accepts a sequence: tuple, list, numpy.ndarray
plt.pie()      # pie chart
plt.bar()      # bar chart
plt.show()     # display the current figure
plt.scatter()  # scatter plot
2.画图
import matplotlib.pyplot as plt

fig = plt.figure()  # create a blank figure (the canvas)
# Split the figure into a 2x2 grid; the final 1 selects the first cell.
ax1 = fig.add_subplot(2, 2, 1)
二、双均线策略
import numpy as np  # only used by the commented-out loop alternative below
import pandas as pd


def find_crosses(df):
    """Return the dates of golden and death crosses of ma5 against ma10.

    A golden cross is a row where the short average (``ma5``) is above the
    long average (``ma10``) while on the previous row it was at or below it.
    A death cross is the mirror image.  The first row has no previous row
    and is therefore never reported.

    :param df: DataFrame with ``ma5`` and ``ma10`` columns; rows with NaN
        averages are expected to have been dropped already.
    :return: tuple ``(golden_cross, death_cross)`` of index objects.
    """
    short_above = df['ma5'] > df['ma10']
    short_below = df['ma5'] < df['ma10']
    # fill_value=False keeps the shifted masks boolean (a plain shift would
    # insert NaN and turn them into object columns) and also guarantees the
    # first row can never count as a cross.
    prev_not_above = (~short_above).shift(1, fill_value=False)
    prev_not_below = (~short_below).shift(1, fill_value=False)
    golden_cross = df[short_above & prev_not_above].index
    death_cross = df[short_below & prev_not_below].index
    return golden_cross, death_cross


if __name__ == '__main__':
    # Imported here so that find_crosses() can be imported without a
    # plotting backend being installed.
    import matplotlib.pyplot as plt

    df = pd.read_csv('601318.csv', index_col='date', parse_dates=['date'])

    # Step 1: compute the moving averages.
    #
    # Alternative 1: explicit loops (very slow).
    # df.loc only accepts row/column labels: rows on the left, columns on
    # the right.
    # df['ma5'] = np.nan
    # df['ma10'] = np.nan
    # for i in range(4, len(df)):
    #     df.loc[df.index[i], 'ma5'] = df['close'][i-4:i+1].mean()
    # for i in range(9, len(df)):
    #     df.loc[df.index[i], 'ma10'] = df['close'][i-9:i+1].mean()
    #
    # Alternative 2: cumulative sums.
    # close         = [10, 11, 12, 13, 14, 15, 16]
    # close.cumsum  = [10, 21, 33, 46, 60, 75, 91]
    # shifted       = [nan, nan, nan, nan, 0, 10, 21, 33, 46, 60, 75, 91]
    # sr = df['close'].cumsum()
    # df['ma5'] = (sr - sr.shift(1).fillna(0).shift(4)) / 5
    # df['ma10'] = (sr - sr.shift(1).fillna(0).shift(9)) / 10
    #
    # Alternative 3 (used): rolling windows.
    df['ma5'] = df['close'].rolling(5).mean()
    df['ma10'] = df['close'].rolling(10).mean()
    df = df.dropna()

    df[['ma5', 'ma10']].plot()
    plt.show()

    # Step 2: detect golden and death crosses.
    #   golden cross: previously short <= long, now short > long
    #   death cross:  previously short >= long, now short < long
    #
    # Alternative 1: explicit loop over a boolean series.
    # sr = df['ma5'] <= df['ma10']
    # golden_cross = []
    # death_cross = []
    # for i in range(1, len(sr)):
    #     # Comparing i with i + 1 was the first idea, but it runs off the
    #     # end of the series, so compare i - 1 with i instead.
    #     if sr.iloc[i - 1] == True and sr.iloc[i] == False:
    #         golden_cross.append(sr.index[i])
    #     if sr.iloc[i - 1] == False and sr.iloc[i] == True:
    #         death_cross.append(sr.index[i])
    #
    # Alternative 2 (used): vectorised comparison against the previous row.
    golden_cross, death_cross = find_crosses(df)
三、一个简单的回测框架
成果展示:
代码:
import datetime

# Import the submodule explicitly: a bare ``import dateutil`` does not
# guarantee that ``dateutil.parser`` is loaded.
import dateutil.parser
import matplotlib.pyplot as plt
import pandas as pd
import tushare

# Load the list of trading days, cached in a CSV file next to the script.
try:
    trade_cal = pd.read_csv("trade_cal.csv")
except (OSError, ValueError):
    # Missing, unreadable or unparsable cache: fetch it again and store it.
    trade_cal = tushare.trade_cal()
    trade_cal.to_csv("trade_cal.csv", index=False)

# The calendar dates are compared as strings against 'YYYY-MM-DD' values
# below, so normalise them to that format whatever the source delivered.
trade_cal['calendarDate'] = pd.to_datetime(
    trade_cal['calendarDate']).dt.strftime('%Y-%m-%d')


class Context:
    """Mutable state of one back-test run."""

    def __init__(self, cash, start_date, end_date):
        """Store the account state and the simulated date range.

        :param cash: available cash
        :param start_date: first day of the strategy, 'YYYY-MM-DD'
        :param end_date: last day of the strategy, 'YYYY-MM-DD'

        Attributes set here in addition to the parameters:
        ``positions`` (security code -> number of shares held),
        ``benchmark`` (reference security, set by the user),
        ``date_range`` (all trading days between start and end) and
        ``dt`` (the current simulated date, updated by the run loop).
        """
        self.cash = cash
        self.start_date = start_date
        self.end_date = end_date
        self.positions = {}  # holdings
        self.benchmark = None
        in_range = (
            (trade_cal['isOpen'] == 1)
            & (trade_cal['calendarDate'] >= start_date)
            & (trade_cal['calendarDate'] <= end_date)
        )
        self.date_range = trade_cal[in_range]['calendarDate'].values
        self.dt = None


class G:
    """Namespace for the user's own global parameters."""
    pass


# Default initial settings.
g = G()
CASH = 100000
START_DATE = '2016-01-07'
END_DATE = '2017-01-31'
context = Context(CASH, START_DATE, END_DATE)


def attribute_history(security, count,
                      field=('open', 'close', 'high', 'low', 'volume')):
    """Return the last ``count`` trading days of quotes before ``context.dt``.

    The window ends on the calendar day before the current simulated date,
    so it moves forward as the run loop advances ``context.dt``.

    :param security: security code
    :param count: number of trading days
    :param field: columns to return
    :return: DataFrame indexed by date
    """
    end_date = (context.dt - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
    open_days = trade_cal[(trade_cal['isOpen'] == 1)
                          & (trade_cal['calendarDate'] <= end_date)]
    start_date = open_days[-count:]['calendarDate'].iloc[0]
    return attribute_daterange_history(security, start_date, end_date, field)


def attribute_daterange_history(security, start_date, end_date,
                                field=('open', 'close', 'high', 'low', 'volume')):
    """Low-level helper: return the quotes of a security for a date range.

    :param security: security code
    :param start_date: first day, 'YYYY-MM-DD'
    :param end_date: last day, 'YYYY-MM-DD'
    :param field: columns to return
    :return: DataFrame indexed by the ``date`` column of the quote data
    """
    df = tushare.get_k_data(security, start_date, end_date)
    df.index = df['date']
    return df[list(field)]


def get_today_data(security):
    """Return the quote row of ``security`` for ``context.dt``.

    :param security: security code
    :return: the row for the current day, or an empty Series when no row
        exists for that day (the security is suspended)
    """
    today = context.dt.strftime('%Y-%m-%d')
    try:
        df = tushare.get_k_data(security, today, today)
        df.index = df['date']
        return df.loc[today]
    except KeyError:
        # No data for today: the security is suspended.
        return pd.Series(dtype=float)


def _order(today_data, security, amount):
    """Low-level order function; trades at today's opening price.

    The requested amount is adjusted, in this order, to what the cash
    allows, to a multiple of 100 shares (unless it sells the whole
    position) and to at most the number of shares held.

    :param today_data: quote row for the current day (may be empty)
    :param security: security code
    :param amount: number of shares; positive buys, negative sells
    :return: None
    """
    # Check for suspension first: an empty row has no 'open' price to read.
    if len(today_data) == 0:
        print("今日停牌")
        return

    p = today_data['open']
    # A security that is not held counts as 0 shares.
    old_amount = context.positions.get(security, 0)

    if context.cash - amount * p < 0:
        amount = context.cash // p
        print('%s:现金不足,已调整为%d' % (today_data['date'], amount))

    # Round towards zero to a multiple of 100 (2345 => 2300), except when
    # the order sells the whole position.
    if amount % 100 != 0 and amount != -old_amount:
        amount = int(amount / 100) * 100
        print('%s:不是100的倍数,已调整为%d' % (today_data['date'], amount))

    if old_amount < -amount:
        amount = -old_amount
        print('%s:卖出股票不能超过持仓数,已调整为%d' % (today_data['date'], amount))

    # Share counts are whole numbers; the adjustments above may leave a float.
    amount = int(amount)

    # Update the holdings and the cash.
    context.positions[security] = old_amount + amount
    context.cash -= amount * p

    # Drop positions that have been closed completely.
    if context.positions[security] == 0:
        del context.positions[security]


def order(security, amount):
    """Trade ``amount`` shares: positive buys, negative sells."""
    today_data = get_today_data(security)
    _order(today_data, security, amount)


def order_target(security, amount):
    """Trade so that the position becomes ``amount`` shares.

    A negative target is reset to 0.  A target below the current holding
    sells, a target above it buys.
    """
    if amount < 0:
        print("数量不能为负,已调整为0")
        amount = 0

    today_data = get_today_data(security)
    # TODO: T+1 rule, distinguish closeable from total holdings
    hold_amount = context.positions.get(security, 0)
    delta_amount = amount - hold_amount
    _order(today_data, security, delta_amount)


def order_value(security, value):
    """Buy (positive) or sell (negative) shares worth ``value``."""
    today_data = get_today_data(security)
    if len(today_data) == 0:
        # Suspended: there is no opening price to convert the value with.
        print("今日停牌")
        return
    amount = value / today_data['open']
    _order(today_data, security, amount)


def order_target_value(security, value):
    """Trade so that the position is worth ``value``; negative becomes 0."""
    if value < 0:
        print("价值不能为负,已调整为0")
        value = 0

    today_data = get_today_data(security)
    if len(today_data) == 0:
        # Suspended: the current holding cannot be valued today.
        print("今日停牌")
        return
    hold_value = context.positions.get(security, 0) * today_data['open']
    delta_value = value - hold_value
    order_value(security, delta_value)


def run():
    """Run the back-test and plot strategy return against the benchmark.

    Calls the user hook ``initialize`` once, then ``handle_data`` for every
    trading day in ``context.date_range``, valuing cash plus holdings at the
    day's opening price after each call.
    """
    plt_df = pd.DataFrame(index=pd.to_datetime(context.date_range),
                          columns=['value'], dtype=float)
    # Starting capital, needed for the rate of return.
    init_value = context.cash
    # Last known price per security, used while it is suspended.
    last_price = {}

    # User hook 1
    initialize(context)

    for dt in context.date_range:
        context.dt = dateutil.parser.parse(dt)
        # User hook 2
        handle_data(context)

        # Total value of cash and holdings.
        value = context.cash
        for stock in context.positions:
            today_data = get_today_data(stock)
            if len(today_data) == 0:
                # Suspended: fall back to the last recorded price.
                p = last_price[stock]
            else:
                p = today_data['open']
                last_price[stock] = p
            value += p * context.positions[stock]
        plt_df.loc[context.dt, 'value'] = value

    plt_df['ratio'] = (plt_df['value'] - init_value) / init_value

    bm_df = attribute_daterange_history(context.benchmark,
                                        context.start_date, context.end_date)
    # Re-index the benchmark by real timestamps so that it aligns with the
    # DatetimeIndex of plt_df when assigned as a column.
    bm_open = pd.Series(bm_df['open'].values,
                        index=pd.to_datetime(bm_df.index))
    bm_init = bm_open.iloc[0]
    plt_df['benchmark_ratio'] = (bm_open - bm_init) / bm_init

    print(plt_df)
    plt_df[['ratio', 'benchmark_ratio']].plot()
    plt.show()


# initialize() and handle_data() below are the user's strategy.

def initialize(context):
    """User hook: choose the benchmark, the security and the two windows."""
    context.benchmark = '601318'
    g.p1 = 5
    g.p2 = 60
    g.security = '601318'


def handle_data(context):
    """User hook: dual moving-average strategy, called once per trading day.

    Buys with all cash when the short average is above the long one and
    nothing is held; sells everything when it is below and a position exists.
    """
    hist = attribute_history(g.security, g.p2)
    short_ma = hist['close'][-g.p1:].mean()
    long_ma = hist['close'].mean()

    if short_ma > long_ma and g.security not in context.positions:
        order_value(g.security, context.cash)
    elif short_ma < long_ma and g.security in context.positions:
        order_target(g.security, 0)


if __name__ == '__main__':
    run()
相关说明:
tushare.trade_cal() # 获取交易日信息,输出结果为:
calendarDate isOpen
0 1990/12/19 1
1 1990/12/20 1
2 1990/12/21 1
3 1990/12/22 0
4 1990/12/23 0
5 1990/12/24 1
6 1990/12/25 1
7 1990/12/26 1
8 1990/12/27 1
9 1990/12/28 1
10 1990/12/29 0
11 1990/12/30 0
12 1990/12/31 1
13 1991/1/1 0