Created
June 8, 2020 13:12
-
-
Save zew13/7ed9470c3137f773b0b934a7750204f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| from fbprophet import Prophet | |
| from glob import glob | |
| from pandas import Series,DataFrame | |
| from os.path import join,dirname,basename | |
| from datetime import date | |
| from collections import defaultdict | |
| import pandas as pd | |
| import csv | |
| import re | |
# Matches each run of decimal digits; used to pull the numeric parts
# (year, month, day) out of a header date cell.
RE_NUM = re.compile(r"\d+")
def next_quarter(numLi):
    """Return the first day of the month three months after ``numLi``.

    ``numLi`` is an indexable ``(year, month, ...)`` sequence; only its first
    two items are read.  The result is always a ``(year, month, 1)`` tuple,
    with the year incremented when the month rolls past December.
    """
    year, month = numLi[0], numLi[1] + 3
    if month <= 12:
        return (year, month, 1)
    # Past December: wrap the month and carry into the following year.
    return (year + 1, month % 12, 1)
# Forecast horizon: number of quarterly periods requested from Prophet and
# appended to the output header (presumably 3 years x 4 quarters).
DEPTH = 3 * 4
# Disabled sample input, kept for reference.
# Glossary: 日期 = dates, 数据 = data, 预测 = forecast;
# 佣金收入 = commission income, 金融服务收入 = financial-services income,
# 利息收入 = interest income, 其他收入 = other income.
# 日期 = "2016-12-31 2017-09-30 2017-12-31 2018-09-30 2018-12-31 2019-03-31 2019-06-30 2019-09-30 2019-12-31 2020-03-31".split()
# 数据 = dict(
# 佣金收入 = "132.01 327.29 524.43 632.04 708.20 635.48 677.21 624.44 732.67 1427.30",
# 金融服务收入 = "3.27 32.28 82.89 157.42 171.93 208.61 191.14 189.89 203.04 163.62",
# 利息收入 = "1.08 5.29 51.88 259.84 505.59 833.21 477.30",
# 其他收入 = "1.62 2.45 1.52 10.98 32.14 64.03 162.08 299.10 228.09 250.56"
# )
#
# 预测 = defaultdict(list)
def extend(ds, li):
    """Forecast the next ``DEPTH`` quarterly values of a series with Prophet.

    Args:
        ds: Sequence of ``(year, month, day)`` tuples, one per observation.
        li: Observed values, in the same order and of the same length as
            ``ds``.

    Returns:
        A list of ``DEPTH`` point forecasts (the ``yhat`` column of Prophet's
        prediction) for the quarter-end dates that follow the last
        observation.

    Raises:
        ValueError: Propagated from ``Prophet.fit`` when the series has fewer
            than two usable observations.
    """
    df = DataFrame({
        "ds": [str(date(*ymd)) for ymd in ds],
        "y": li,
    })
    m = Prophet(
        weekly_seasonality=False,
        daily_seasonality=False,
        # One potential trend changepoint per 12 observations (0 for short
        # series).
        n_changepoints=len(li) // 12,
        seasonality_prior_scale=.5,
        # Let changepoints be placed anywhere in the history (proportion 1).
        changepoint_range=1,
    )
    m.fit(df)
    # Only the future quarter-end dates are predicted; history is excluded.
    future = m.make_future_dataframe(
        periods=DEPTH, freq="Q", include_history=False
    )
    forecast = m.predict(future)
    # The forecast frame also carries yhat_lower / yhat_upper bounds; only
    # the point estimate is returned.
    return list(forecast.yhat)
| # | |
| # # fig1 = m.plot(forecast) | |
| # # fig2 = m.plot_components(forecast) | |
| # | |
| # # from statsmodels.tsa.api import Holt | |
| # # import numpy as np | |
| # # import matplotlib.pyplot as plt | |
| # # | |
| # # fit = Holt(np.asarray(income)).fit(smoothing_level=1.406, smoothing_slope=0.2) | |
| # # | |
| # # for i in fit.forecast(4): | |
| # # print(round(i,2)) | |
| # | |
| # # plt.figure() | |
| # # plt.plot(income[:-1]+list(fit.forecast(5))) | |
| # # plt.show() | |
# Directory containing this script; the "csv" input folder and the "future"
# output folder are both resolved relative to it.
_DIR = dirname(__file__)
def _per_quarter_cells(dateLi, cells):
    """Convert one row of raw figures into per-quarter values.

    Each figure is treated as cumulative within its calendar year: the first
    figure of a year is divided by the number of quarters elapsed, and every
    later figure has the previous raw figure subtracted before being divided
    by the number of quarters between the two dates.

    Args:
        dateLi: Header dates as ``(year, month, day)`` tuples.
        cells: The row's raw string cells, aligned with ``dateLi``.

    Returns:
        ``(row_li, ds_li, val_li)``: the formatted output cells (one per
        header date, ``""`` where the input cell is blank or missing), and
        the dates and numeric values of the non-blank cells.
    """
    row_li = []
    ds_li = []
    val_li = []
    prev_raw = None  # previous raw (cumulative) figure seen in this row
    for d, cell in zip(dateLi, cells):
        cell = cell.strip()
        if not cell:
            row_li.append("")
            continue
        raw = float(cell.replace(',', ''))
        if not ds_li or ds_li[-1][0] != d[0]:
            # First figure of a new year: average over the quarters elapsed.
            quarters = d[1] // 3
        else:
            # Same year: only the increase since the previous figure counts.
            quarters = (d[1] - ds_li[-1][1]) // 3
            raw_delta = raw - prev_raw
        # max(1, ...) avoids ZeroDivisionError for dates less than a full
        # quarter into the year or less than a quarter apart.
        if not ds_li or ds_li[-1][0] != d[0]:
            val = raw / max(1, quarters)
        else:
            val = raw_delta / max(1, quarters)
        prev_raw = raw
        row_li.append(str(round(val, 4)))
        ds_li.append(d)
        val_li.append(val)
    # Pad short rows so the forecast cells stay under their header columns.
    row_li.extend([""] * (len(dateLi) - len(row_li)))
    return row_li, ds_li, val_li


def csv_future(csvname):
    """Read ``csv/<csvname>.csv`` and write ``future/<csvname>.csv``.

    The first input row is the header: its first cell is ignored and the
    following cells are dates (read up to the first blank cell).  Every later
    row is ``label, figure, figure, ...``; rows with a blank label are
    skipped.  The output repeats the header (as ``YYYY-MM``) extended by
    ``DEPTH`` further quarters, and each row's per-quarter values followed by
    ``DEPTH`` forecast values from ``extend``.

    Args:
        csvname: File name without the ``.csv`` suffix, relative to the
            ``csv`` and ``future`` folders next to this script.
    """
    from os import makedirs

    out_dir = join(_DIR, "future")
    # The output folder may not exist on a fresh checkout.
    makedirs(out_dir, exist_ok=True)

    # newline="" is what the csv module requires of files handed to
    # reader/writer; without it the writer emits "\r\r\n" on Windows.
    with open(join(_DIR, "csv", csvname + '.csv'),
              encoding="utf-8-sig", newline="") as src, \
            open(join(out_dir, csvname + '.csv'), "w",
                 encoding="utf-8-sig", newline="") as dst:
        writer = csv.writer(dst)
        dateLi = []
        for pos, row in enumerate(csv.reader(src)):
            if pos == 0:
                # Header row: collect dates until the first blank cell.
                for cell in row[1:]:
                    cell = cell.strip()
                    if not cell:
                        break
                    dateLi.append(tuple(map(int, RE_NUM.findall(cell))))
                header = dateLi[:]
                ymd = dateLi[-1]
                for _ in range(DEPTH):
                    ymd = next_quarter(ymd)
                    header.append(ymd)
                writer.writerow([""] + [str(date(*x))[:7] for x in header])
                continue

            # csv.reader yields [] for blank lines; treat them like rows
            # with a blank label.
            key = row[0].strip() if row else ""
            if not key:
                continue

            row_li, ds_li, val_li = _per_quarter_cells(dateLi, row[1:])
            if len(val_li) >= 2:
                forecast = [str(round(x, 4)) for x in extend(ds_li, val_li)]
            else:
                # Prophet cannot fit fewer than two observations; leave the
                # forecast cells blank instead of aborting the whole file.
                forecast = [""] * DEPTH
            writer.writerow([key] + row_li + forecast)
if __name__ == "__main__":
    # Forecast every CSV file found in the "csv" folder next to this script.
    for csv_path in glob(join(_DIR, "csv/*.csv")):
        stem = basename(csv_path)[:-4]  # drop the ".csv" suffix
        print(stem)
        csv_future(stem)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment