Skip to content

Instantly share code, notes, and snippets.

@zew13
Created June 8, 2020 13:12
Show Gist options
  • Select an option

  • Save zew13/7ed9470c3137f773b0b934a7750204f9 to your computer and use it in GitHub Desktop.

Select an option

Save zew13/7ed9470c3137f773b0b934a7750204f9 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
from fbprophet import Prophet
from glob import glob
from pandas import Series,DataFrame
from os.path import join,dirname,basename
from datetime import date
from collections import defaultdict
import pandas as pd
import csv
import re
# Matches runs of decimal digits; used to split header date cells such as
# "2019-12-31" into their integer parts (2019, 12, 31).
RE_NUM = re.compile("\\d+")
def next_quarter(numLi):
    """Return the first day of the month three months after ``numLi``.

    ``numLi`` is a (year, month, ...) sequence; only its first two items are
    read. The result is always a ``(year, month, 1)`` tuple. For months 1-12
    the year rolls over when adding three months passes December.
    """
    year, month = numLi[0], numLi[1] + 3
    if month <= 12:
        return (year, month, 1)
    return (year + 1, month % 12, 1)
# Number of quarters projected past the last reported date (3 years x 4
# quarters). Used both as the Prophet forecast horizon (freq="Q") and as the
# count of extra quarter columns appended to the output header.
DEPTH = 3 * 4
# 日期 = "2016-12-31 2017-09-30 2017-12-31 2018-09-30 2018-12-31 2019-03-31 2019-06-30 2019-09-30 2019-12-31 2020-03-31".split()
# 数据 = dict(
# 佣金收入 = "132.01 327.29 524.43 632.04 708.20 635.48 677.21 624.44 732.67 1427.30",
# 金融服务收入 = "3.27 32.28 82.89 157.42 171.93 208.61 191.14 189.89 203.04 163.62",
# 利息收入 = "1.08 5.29 51.88 259.84 505.59 833.21 477.30",
# 其他收入 = "1.62 2.45 1.52 10.98 32.14 64.03 162.08 299.10 228.09 250.56"
# )
#
# 预测 = defaultdict(list)
def extend(ds, li, periods=None):
    """Fit a Prophet model to a quarterly series and forecast past its end.

    Args:
        ds: date tuples accepted by ``datetime.date`` (e.g. ``(2019, 12, 31)``),
            one per value in ``li``.
        li: numeric observations aligned with ``ds``.
        periods: number of quarter-end steps to forecast after the last date
            in ``ds``. ``None`` (the default) means the module-level ``DEPTH``.

    Returns:
        A list of ``periods`` point forecasts (Prophet's ``yhat`` column). It
        covers the future quarters only; the fitted history is not included.

    Raises:
        ValueError: propagated from ``Prophet.fit`` when fewer than two
            observations are supplied.
    """
    if periods is None:
        periods = DEPTH
    df = DataFrame({
        "ds": [str(date(*x)) for x in ds],
        "y": li,
    })
    model = Prophet(
        weekly_seasonality=False,
        daily_seasonality=False,
        # Roughly one potential trend changepoint per 12 observations
        # (zero for series shorter than 12 points).
        n_changepoints=len(li) // 12,
        # Well below Prophet's default of 10, which damps the seasonal component.
        seasonality_prior_scale=.5,
        # Allow changepoints across the whole history, not only the first 80%.
        changepoint_range=1,
    )
    model.fit(df)
    future = model.make_future_dataframe(
        periods=periods, freq="Q", include_history=False
    )
    forecast = model.predict(future)
    return list(forecast.yhat)
#
# # fig1 = m.plot(forecast)
# # fig2 = m.plot_components(forecast)
#
# # from statsmodels.tsa.api import Holt
# # import numpy as np
# # import matplotlib.pyplot as plt
# #
# # fit = Holt(np.asarray(income)).fit(smoothing_level=1.406, smoothing_slope=0.2)
# #
# # for i in fit.forecast(4):
# # print(round(i,2))
#
# # plt.figure()
# # plt.plot(income[:-1]+list(fit.forecast(5)))
# # plt.show()
# Directory holding this script. csv_future reads "<_DIR>/csv/<name>.csv" and
# writes "<_DIR>/future/<name>.csv"; the __main__ loop globs "<_DIR>/csv".
# NOTE(review): dirname() of a bare filename is "", in which case these paths
# resolve against the current working directory instead — confirm intended.
_DIR = dirname(__file__)
def _parse_header_dates(row):
    """Return the header row's date cells as tuples of ints.

    The first (row-label) cell is skipped. Cells are read left to right and
    parsing stops at the first blank one. Every run of digits becomes one
    int, so "2019-12-31" -> (2019, 12, 31).
    """
    dates = []
    for cell in row[1:]:
        cell = cell.strip()
        if not cell:
            break
        dates.append(tuple(map(int, RE_NUM.findall(cell))))
    return dates


def _output_header(dates):
    """Build the output header row.

    The row is a blank label cell, then the input dates followed by DEPTH
    projected quarters, each rendered as "YYYY-MM".
    """
    columns = list(dates)
    ymd = dates[-1]
    for _ in range(DEPTH):
        ymd = next_quarter(ymd)
        columns.append(ymd)
    return [""] + [str(date(*d))[:7] for d in columns]


def _per_quarter(dates, cells):
    """Convert one data row into per-quarter values.

    The arithmetic treats each figure as a year-to-date cumulative total.
    The first figure seen in a calendar year is divided by the quarters
    elapsed (month // 3). A later figure in the same year contributes its
    increase over the previous figure, divided by the quarters between the
    two dates. Blank cells are skipped.

    Returns:
        (row, ds, values):
        - ``row`` has one output cell per header date: the value rounded to
          4 places as a string, or "" for a blank input cell.
        - ``ds`` and ``values`` hold only the non-blank dates and their
          per-quarter values, ready for ``extend``.
    """
    # Pad rows shorter than the header so every date gets an output cell
    # and the forecast columns stay under the correct header dates.
    cells = list(cells) + [""] * (len(dates) - len(cells))
    row, ds, values = [], [], []
    prev_total = None
    for d, cell in zip(dates, cells):
        cell = cell.strip()
        if not cell:
            row.append("")
            continue
        total = float(cell.replace(",", ""))
        if not ds or ds[-1][0] != d[0]:
            # First figure of this year: spread it over the quarters so far.
            value = total / (d[1] // 3)
        else:
            # Same year: spread the increase over the quarters in between.
            value = (total - prev_total) / ((d[1] - ds[-1][1]) // 3)
        prev_total = total
        row.append(str(round(value, 4)))
        ds.append(d)
        values.append(value)
    return row, ds, values


def csv_future(csvname):
    """Extend one CSV table with per-quarter values and a Prophet forecast.

    Reads ``<_DIR>/csv/<csvname>.csv`` and writes
    ``<_DIR>/future/<csvname>.csv``, both UTF-8 with BOM. The ``future``
    directory is created if missing.

    The first input row is the header: a label cell followed by dates. Each
    later row with a non-blank label is written as:
    - the label;
    - its per-quarter values (see ``_per_quarter``);
    - ``DEPTH`` forecast values from ``extend``.
    Blank lines and rows with a blank label are skipped.

    Raises:
        ValueError: if the header row contains no dates.
    """
    from os import makedirs

    out_dir = join(_DIR, "future")
    # The csv module requires files opened with newline=""; without it the
    # writer emits an extra blank line per row on Windows.
    with open(join(_DIR, "csv", csvname + ".csv"),
              encoding="utf-8-sig", newline="") as src:
        makedirs(out_dir, exist_ok=True)
        with open(join(out_dir, csvname + ".csv"), "w",
                  encoding="utf-8-sig", newline="") as dst:
            reader = csv.reader(src)
            writer = csv.writer(dst)
            dates = None
            for row in reader:
                if dates is None:
                    dates = _parse_header_dates(row)
                    if not dates:
                        raise ValueError(
                            "%s: header row has no dates" % csvname)
                    writer.writerow(_output_header(dates))
                    continue
                # csv.reader yields [] for a blank line.
                key = row[0].strip() if row else ""
                if not key:
                    continue
                cells, ds, values = _per_quarter(dates, row[1:])
                if len(values) >= 2:
                    forecast = [str(round(x, 4)) for x in extend(ds, values)]
                else:
                    # Prophet.fit rejects fewer than two observations; leave
                    # the projected columns blank rather than abort the run.
                    forecast = [""] * DEPTH
                writer.writerow([key] + cells + forecast)
if __name__ == "__main__":
    # Process every "<_DIR>/csv/*.csv" table. Each file's base name (minus
    # ".csv") is echoed, then handed to csv_future.
    stems = [basename(path)[:-4] for path in glob(join(_DIR, "csv/*.csv"))]
    for stem in stems:
        print(stem)
        csv_future(stem)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment