Created August 1, 2018 13:27
import datacube as dc
import ctypes
from contextlib import closing
import numpy as np
from datacube.helpers import ga_pq_fuser
from datacube.storage import masking
import xarray as xr
import multiprocessing as mp

#N = 1001*1001
def dist_geomedian(params):
    """Worker: geometric median for pixels params[0]..params[1] of the shared (bands, times, pixels) cube."""
    gmed = np.frombuffer(shared_out_arr.get_obj(), dtype=np.float32).reshape((params[2][0], params[2][2]))
    X = np.frombuffer(shared_in_arr.get_obj(), dtype=np.int16).reshape(params[2])

    for i in range(params[0], params[1]):
        gmed[:, i] = geometric_median(X[:, :, i], params[3], params[4])
def geometric_median(x, epsilon, max_iter):
    """Weiszfeld-style iteration for the geometric median of the columns of x.

    x is a (bands, observations) array; returns a (bands,) vector. Starts from the
    per-band mean and stops once the update moves less than epsilon.
    """
    y0 = np.nanmean(x, axis=1)
    if len(y0[np.isnan(y0)]) > 0:
        return y0

    for _ in range(max_iter):
        euc_dist = np.transpose(np.transpose(x) - y0)
        euc_norm = np.sqrt(np.sum(euc_dist ** 2, axis=0))
        not_nan = np.where(~np.isnan(euc_norm))[0]
        y1 = np.sum(x[:, not_nan] / euc_norm[not_nan], axis=1) / np.sum(1 / euc_norm[not_nan])
        if len(y1[np.isnan(y1)]) > 0 or np.sqrt(np.sum((y1 - y0) ** 2)) < epsilon:
            return y1
        y0 = y1

    return y0
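# Quick sanity-check sketch (made-up values, not from the original gist): each pixel is passed
# in as a (bands, observations) array and the median is taken across observations, e.g.
#x = np.array([[1., 2., 10.],
#              [1., 2., 10.]], dtype=np.float32)
#print(geometric_median(x, epsilon=1e-3, max_iter=40))  # converges towards [2., 2.], the middle point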
class BurnCube(dc.Datacube):
    def __init__(self):
        super(BurnCube, self).__init__(app='TreeMapping.getLandsatStack')

        self.dataset = None
        self.geomed = None

    def to_netcdf(self, path):
        self.dataset.to_netcdf(path)

    def open_dataset(self, path):
        self.dataset = xr.open_dataset(path)
    def _load_pq(self, x, y, res, period, n_landsat):
        query = {
            'time': period,
            'x': x,
            'y': y,
            'crs': 'EPSG:3577',
            'measurements': ['pixelquality'],
            'resolution': res,
        }

        pq_stack = []
        for n in n_landsat:
            pq_stack.append(self.load(product='ls{}_pq_albers'.format(n),
                                      group_by='solar_day', fuse_func=ga_pq_fuser,
                                      resampling='nearest', **query))

        pq_stack = xr.concat(pq_stack, dim='time').sortby('time')
        pq_stack['land'] = masking.make_mask(pq_stack.pixelquality, land_sea='land')
        pq_stack['no_cloud'] = masking.make_mask(pq_stack.pixelquality, cloud_acca='no_cloud',
                                                 cloud_fmask='no_cloud',
                                                 cloud_shadow_acca='no_cloud_shadow',
                                                 cloud_shadow_fmask='no_cloud_shadow')

        return pq_stack
    def _load_nbart(self, x, y, res, period, n_landsat):
        query = {
            'time': period,
            'x': x,
            'y': y,
            'crs': 'EPSG:3577',
            'measurements': ['red', 'green', 'blue', 'nir', 'swir1', 'swir2'],
            'resolution': res,
        }

        nbart_stack = []
        for n in n_landsat:
            dss = self.find_datasets(product='ls{}_nbart_albers'.format(n), **query)
            nbart_stack.append(self.load(product='ls{}_nbart_albers'.format(n),
                                         group_by='solar_day', datasets=dss,
                                         resampling='bilinear', **query))

        nbart_stack = xr.concat(nbart_stack, dim='time').sortby('time')

        return nbart_stack
    def load_cube(self, x, y, res, period, n_landsat):
        nbart_stack = self._load_nbart(x, y, res, period, n_landsat)
        pq_stack = self._load_pq(x, y, res, period, n_landsat)
        pq_stack, nbart_stack = xr.align(pq_stack, nbart_stack, join='inner')

        # A pixel is good when it is cloud free and has a valid (positive) red reflectance.
        pq_stack['good_pixel'] = pq_stack.no_cloud.where(nbart_stack.red > 0, False, drop=False)
        goodpix = pq_stack.no_cloud * (pq_stack.pixelquality > 0) * pq_stack.good_pixel
        print(pq_stack.no_cloud.shape)
        print(pq_stack.good_pixel)
        print(goodpix)

        # Keep only the time slices where more than 20% of the pixels are good.
        mask = np.nanmean(goodpix.values.reshape(goodpix.shape[0], -1), axis=1) > .2
        # Zero out bad pixels by multiplying reflectance with the 0/1 good-pixel flags.
        cubes = [nbart_stack[band][mask, :, :] * goodpix[mask, :, :]
                 for band in ['red', 'green', 'blue', 'nir', 'swir1', 'swir2']]
        X = np.stack(cubes, axis=0)
        print(X.shape)

        #data = xr.Dataset(coords={'band': ['red','green','blue','nir','swir1','swir2'],
        data = xr.Dataset(coords={'band': np.arange(6),
                                  'time': nbart_stack.time[mask],
                                  'y': nbart_stack.y[:],
                                  'x': nbart_stack.x[:]},
                          attrs={'crs': 'EPSG:3577'})
        data["cube"] = (('band', 'time', 'y', 'x'), X)
        data.time.attrs = {}

        self.dataset = data
    def geomedian(self, period, n_procs=4, epsilon=.5, max_iter=40):
        # Define an output queue
        #output = mp.Queue()
        """
        X = np.empty((len(names),len(self.dataset.time),len(self.dataset.x)*len(self.dataset.y)))
        for i, name in enumerate(names):
            X[i,:,:] = self.dataset[name].data.reshape(len(self.dataset.time), -1)
        """
        n = len(self.dataset.y) * len(self.dataset.x)

        # Shared output buffer: one float32 geometric median per band and pixel.
        out_arr = mp.Array(ctypes.c_float, len(self.dataset.band) * n)
        gmed = np.frombuffer(out_arr.get_obj(), dtype=np.float32).reshape((len(self.dataset.band), n))
        gmed.fill(np.nan)

        _X = self.dataset['cube'].sel(time=slice(period[0], period[1]))
        print(_X.shape)
        t_dim = _X.time[:]

        # Shared input buffer: the (band, time, pixel) cube as int16.
        in_arr = mp.Array(ctypes.c_short, len(self.dataset.band) * len(_X.time) * n)
        X = np.frombuffer(in_arr.get_obj(), dtype=np.int16).reshape(len(self.dataset.band), len(_X.time), n)
        X[:] = _X.data.reshape(len(self.dataset.band), len(_X.time), -1)
        # X[X<=0] = np.nan

        def init(shared_in_arr_, shared_out_arr_):
            global shared_in_arr
            global shared_out_arr

            shared_in_arr = shared_in_arr_    # must be inherited, not passed as an argument
            shared_out_arr = shared_out_arr_  # must be inherited, not passed as an argument

        # write to arr from different processes
        with closing(mp.Pool(n_procs, initializer=init, initargs=(in_arr, out_arr,))) as p:
            # many processes access different slices of the same array
            chunk = n // n_procs
            #p.map_async(g, [(i, i + step) for i in range(0, N, step)])
            p.map_async(dist_geomedian,
                        [(i, min(n, i + chunk), X.shape, epsilon, max_iter) for i in range(0, n, chunk)])
        p.join()

        print(gmed)
        ds = xr.Dataset(coords={'time': t_dim, 'y': self.dataset.y[:], 'x': self.dataset.x[:],
                                'bands': self.dataset.band},
                        attrs={'crs': 'EPSG:3577'})
        ds['geomedian'] = (('bands', 'y', 'x'),
                           gmed[:].reshape((len(self.dataset.band),
                                            len(self.dataset.y),
                                            len(self.dataset.x))).astype(np.float32))

        self.geomed = ds
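# Rough sketch of the work split inside geomedian() (hypothetical sizes, not from this gist):
# with n = 10 pixels and n_procs = 4, chunk = n // n_procs = 2 and map_async receives the
# ranges (0, 2), (2, 4), (4, 6), (6, 8), (8, 10); each worker writes a disjoint column range
# of the shared gmed buffer, so the workers never overlap.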
#module use /g/data/v10/public/modules/modulefiles
#module load agdc-py3-prod

import pyproj

wgs84 = pyproj.Proj(init='epsg:4326')
gda94 = pyproj.Proj(init='epsg:3577')
easting, northing = pyproj.transform(wgs84, gda94, 150.4696025, -22.5578545)

x = (easting - 12500, easting + 12500)  # east-west extent, assumed to mirror the 25 km north-south extent below
y = (northing + 12500, northing - 12500)
res = (25, 25)

reftime = ('2013-01-01', '2017-12-31')  # full period used to load the Landsat cube
period = ('2013-01-01', '2016-12-31')   # period used for the calculation of geometric median
bc = BurnCube()
#bc.load_cube(x, y, res, reftime, [8])
#bc.to_netcdf("/g/data/xc0/project/Burn_Mapping/bc_cube_test.nc")
bc.open_dataset("/g/data/xc0/project/Burn_Mapping/bc_cube_test.nc")
#print(bc.dataset)
#print(bc)

import time
start = time.time()
bc.geomedian(period, n_procs=8)
print(bc.geomed)
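# Possible follow-up (not part of the original gist): report the elapsed time and persist
# the result; the output filename below is a hypothetical example.
print('geomedian took {:.1f} s'.format(time.time() - start))
#bc.geomed.to_netcdf("/g/data/xc0/project/Burn_Mapping/bc_geomed_test.nc")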