Sentinel 2A/2B band extraction and cropping using rasterio
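"""
Sentinel-2A/2B band extraction and cropping with rasterio.

Enumerates .SAFE products in the working directory, groups tiles by datatake
time, merges the B02/B03/B04/B08 bands per group, computes NDVI, and crops
per-field NDVI and RGB pixel stacks using the train/test shapefiles. The
resulting per-field time series are pickled to disk.
"""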
import os
import glob
import pyproj
import shapely
import shapely.geometry
import shapely.ops
import fiona
import rasterio
import rasterio.mask
import rasterio.merge
import numpy
import pickle
def project_wgs_shape_to_crs(shape, from_crs, to_crs):
    # Reproject a shapely geometry from one CRS to another
    project = lambda x, y: pyproj.transform(
        from_crs,
        to_crs,
        x,
        y
    )
    return shapely.ops.transform(project, shape)
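
# Note: pyproj.transform() is deprecated in pyproj 2.x and removed in later
# releases. A minimal sketch of the same helper using the Transformer API,
# assuming pyproj 2+ is available:
#
#   def project_wgs_shape_to_crs(shape, from_crs, to_crs):
#       transformer = pyproj.Transformer.from_proj(from_crs, to_crs, always_xy=True)
#       return shapely.ops.transform(transformer.transform, shape)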
train_shapefile = fiona.open("train/train.shp", "r")
train_shape_crs = pyproj.Proj(train_shapefile.crs)
test_shapefile = fiona.open("test/test.shp", "r")
test_shape_crs = pyproj.Proj(test_shapefile.crs)
#print(shapefile.crs)
# Start by enumerating SAFE products
# TODO: check cloud contamination using s2cloudless
product_groups = {}
train_field_data = {}
train_field_data_r = {}
train_field_data_g = {}
train_field_data_b = {}
test_field_data = {}
test_field_data_r = {}
test_field_data_g = {}
test_field_data_b = {}
for product_fn in glob.glob('*.SAFE'):
    #print(product_fn)
    """
    The compact naming convention is arranged as follows:
    MMM_MSIL1C_YYYYMMDDHHMMSS_Nxxyy_ROOO_Txxxxx_<Product Discriminator>.SAFE
    The products contain two dates.
    The first date (YYYYMMDDHHMMSS) is the datatake sensing time.
    The second date is the "<Product Discriminator>" field, which is 15 characters in length and is used to distinguish between different end user products from the same datatake. Depending on the instance, the time in this field can be earlier or slightly later than the datatake sensing time.
    The other components of the filename are:
    MMM: the mission ID (S2A/S2B)
    MSIL1C: denotes the Level-1C product level
    YYYYMMDDHHMMSS: the datatake sensing start time
    Nxxyy: the Processing Baseline number (e.g. N0204)
    ROOO: the Relative Orbit number (R001 - R143)
    Txxxxx: the Tile Number field
    SAFE: Product Format (Standard Archive Format for Europe)
    """
    # Split the product name into parts
    product_attrs = product_fn.split('_')
    datatake_time = product_attrs[2]
    tile_number = product_attrs[5]
    # Since the shape files provided cover two tiles, group tiles by datatake_time
    if datatake_time in product_groups:
        product_groups[datatake_time].append(product_fn)
    else:
        product_groups[datatake_time] = [product_fn]
# Sort the dict in chronological order
product_groups = dict(sorted(product_groups.items()))
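# Because the groups are processed in chronological order, each field's stack
# (built along axis 2 below) becomes a time-ordered series of observations.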
# Enumerate groups of tiles
for product_group in product_groups:
    print('*** Processing {}..'.format(product_group))
    b2 = []  # all B2 bands for a group, blue
    b3 = []  # all B3 bands for a group, green
    b4 = []  # all B4 bands for a group, red
    b8 = []  # all B8 bands for a group, near-infrared
    for product_fn in product_groups[product_group]:
        print('  {}'.format(product_fn))
        b2fn = ''
        b3fn = ''
        b4fn = ''
        b8fn = ''
        for bandfn in glob.glob('{}/GRANULE/*/IMG_DATA/*.jp2'.format(product_fn)):
            # Split the band file name
            base = os.path.basename(bandfn)
            band_attrs = os.path.splitext(base)[0].split('_')
            band_type = band_attrs[2]  # B01, B02, etc
            if band_type == 'B02':
                b2fn = bandfn
            if band_type == 'B03':
                b3fn = bandfn
            if band_type == 'B04':
                b4fn = bandfn
            if band_type == 'B08':
                b8fn = bandfn
        assert b2fn and b3fn and b4fn and b8fn  # all four bands should be present
        b2.append(rasterio.open(b2fn))
        b3.append(rasterio.open(b3fn))
        b4.append(rasterio.open(b4fn))
        b8.append(rasterio.open(b8fn))
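    # B02, B03, B04 and B08 are all 10 m bands in Sentinel-2 L1C products, so
    # the mosaics merged below share one grid and can be combined elementwise.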
    print('  Merging bands..')
    # For a group of tiles/products, merge bands from different tiles together
    blue, _ = rasterio.merge.merge(b2)
    green, _ = rasterio.merge.merge(b3)
    red, out_trans = rasterio.merge.merge(b4)
    nir, _ = rasterio.merge.merge(b8)
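    # rasterio.merge.merge() returns the mosaic as a (bands, height, width)
    # array together with its affine transform; the transform from the red
    # mosaic is reused below when writing the GeoTIFFs.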
    # Calculate the NDVI = (NIR - Red) / (NIR + Red) from the merged B8 and B4 mosaics
    print('  Calculating the NDVI..')
    # Compute in float to avoid uint16 overflow in the sum; suppress warnings
    # where nir + red == 0 (e.g. nodata areas), which yields nan/inf pixels
    with numpy.errstate(divide='ignore', invalid='ignore'):
        ndvi = (nir.astype(float) - red.astype(float)) / (nir.astype(float) + red.astype(float))
    # Save the NDVI image for manual analysis later
    print('  Saving the NDVI raster to ndvi/{}.tif..'.format(product_group))
    meta = b4[0].meta.copy()
    meta.update(dtype=rasterio.float64,
                compress='lzw',
                driver='GTiff',
                transform=out_trans,
                height=red.shape[1],
                width=red.shape[2])
    with rasterio.open('ndvi/{}.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(ndvi)
    # Scale the 16-bit band values (0..65535) in r, g, b to 0..1
    red = red.astype(float) / 65535
    green = green.astype(float) / 65535
    blue = blue.astype(float) / 65535
    # Save the red, green and blue images as well
    print('  Saving the RGB rasters to rgb/{}-r/g/b.tif..'.format(product_group))
    with rasterio.open('rgb/{}-r.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(red)
    with rasterio.open('rgb/{}-g.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(green)
    with rasterio.open('rgb/{}-b.tif'.format(product_group), 'w', **meta) as dst:
        dst.write(blue)
    ndvi_img = rasterio.open('ndvi/{}.tif'.format(product_group))
    #print('  NDVI CRS is', ndvi_img.crs.data)
    ndvi_crs = pyproj.Proj(ndvi_img.crs)
    red_img = rasterio.open('rgb/{}-r.tif'.format(product_group))
    red_crs = pyproj.Proj(red_img.crs)
    green_img = rasterio.open('rgb/{}-g.tif'.format(product_group))
    green_crs = pyproj.Proj(green_img.crs)
    blue_img = rasterio.open('rgb/{}-b.tif'.format(product_group))
    blue_crs = pyproj.Proj(blue_img.crs)
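    # The rasters are re-opened from disk because rasterio.mask.mask() expects
    # an open dataset (carrying its transform and CRS), not a bare numpy array.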
    # Alright, NDVI is ready for the whole region in question
    # Use the shape file to mask out everything except the fields
    for field in train_shapefile:
        #print(field['properties']['Field_Id'], field['properties']['Crop_Id_Ne'])
        field_id = field['properties']['Field_Id']
        #print('  Cropping NDVI data for train field #{}'.format(field_id))
        try:
            projected_shape = project_wgs_shape_to_crs(shapely.geometry.shape(field['geometry']),
                                                       train_shape_crs,
                                                       ndvi_crs)
        except Exception as e:
            print('  ', e, ' exception for field #', field_id)
            continue
        #print(projected_shape)
        field_img, field_img_transform = rasterio.mask.mask(ndvi_img, [projected_shape], crop=True)
        field_img_red, _ = rasterio.mask.mask(red_img, [projected_shape], crop=True)
        field_img_green, _ = rasterio.mask.mask(green_img, [projected_shape], crop=True)
        field_img_blue, _ = rasterio.mask.mask(blue_img, [projected_shape], crop=True)
        # Remove the leading band dimension: (1, H, W) -> (H, W)
        field_img = numpy.squeeze(field_img, axis=0)
        field_img_red = numpy.squeeze(field_img_red, axis=0)
        field_img_green = numpy.squeeze(field_img_green, axis=0)
        field_img_blue = numpy.squeeze(field_img_blue, axis=0)
        # Add a trailing time dimension: (H, W) -> (H, W, 1)
        field_img = numpy.expand_dims(field_img, 2)
        field_img_red = numpy.expand_dims(field_img_red, 2)
        field_img_green = numpy.expand_dims(field_img_green, 2)
        field_img_blue = numpy.expand_dims(field_img_blue, 2)
        if field_id in train_field_data:
            # Append this datatake along the time axis
            train_field_data[field_id] = numpy.concatenate((train_field_data[field_id], field_img), axis=2)
            train_field_data_r[field_id] = numpy.concatenate((train_field_data_r[field_id], field_img_red), axis=2)
            train_field_data_g[field_id] = numpy.concatenate((train_field_data_g[field_id], field_img_green), axis=2)
            train_field_data_b[field_id] = numpy.concatenate((train_field_data_b[field_id], field_img_blue), axis=2)
        else:
            train_field_data[field_id] = field_img
            train_field_data_r[field_id] = field_img_red
            train_field_data_g[field_id] = field_img_green
            train_field_data_b[field_id] = field_img_blue
    for field in test_shapefile:
        #print(field['properties']['Field_Id'], field['properties']['Crop_Id_Ne'])
        field_id = field['properties']['Field_Id']
        #print('  Cropping NDVI data for test field #{}'.format(field_id))
        try:
            projected_shape = project_wgs_shape_to_crs(shapely.geometry.shape(field['geometry']),
                                                       test_shape_crs,
                                                       ndvi_crs)
        except Exception as e:
            print('  ', e, ' exception for field #', field_id)
            continue
        #print(projected_shape)
        field_img, field_img_transform = rasterio.mask.mask(ndvi_img, [projected_shape], crop=True)
        field_img_red, _ = rasterio.mask.mask(red_img, [projected_shape], crop=True)
        field_img_green, _ = rasterio.mask.mask(green_img, [projected_shape], crop=True)
        field_img_blue, _ = rasterio.mask.mask(blue_img, [projected_shape], crop=True)
        # Remove the leading band dimension: (1, H, W) -> (H, W)
        field_img = numpy.squeeze(field_img, axis=0)
        field_img_red = numpy.squeeze(field_img_red, axis=0)
        field_img_green = numpy.squeeze(field_img_green, axis=0)
        field_img_blue = numpy.squeeze(field_img_blue, axis=0)
        # Add a trailing time dimension: (H, W) -> (H, W, 1)
        field_img = numpy.expand_dims(field_img, 2)
        field_img_red = numpy.expand_dims(field_img_red, 2)
        field_img_green = numpy.expand_dims(field_img_green, 2)
        field_img_blue = numpy.expand_dims(field_img_blue, 2)
        if field_id in test_field_data:
            # Append this datatake along the time axis
            test_field_data[field_id] = numpy.concatenate((test_field_data[field_id], field_img), axis=2)
            test_field_data_r[field_id] = numpy.concatenate((test_field_data_r[field_id], field_img_red), axis=2)
            test_field_data_g[field_id] = numpy.concatenate((test_field_data_g[field_id], field_img_green), axis=2)
            test_field_data_b[field_id] = numpy.concatenate((test_field_data_b[field_id], field_img_blue), axis=2)
        else:
            test_field_data[field_id] = field_img
            test_field_data_r[field_id] = field_img_red
            test_field_data_g[field_id] = field_img_green
            test_field_data_b[field_id] = field_img_blue
    # Close the datasets opened for this group to avoid leaking file handles
    for ds in b2 + b3 + b4 + b8 + [ndvi_img, red_img, green_img, blue_img]:
        ds.close()
# Save the per-field data to file
with open('train/train.pkl', 'wb') as f:
    pickle.dump(train_field_data, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('train/train-r.pkl', 'wb') as f:
    pickle.dump(train_field_data_r, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('train/train-g.pkl', 'wb') as f:
    pickle.dump(train_field_data_g, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('train/train-b.pkl', 'wb') as f:
    pickle.dump(train_field_data_b, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('test/test.pkl', 'wb') as f:
    pickle.dump(test_field_data, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('test/test-r.pkl', 'wb') as f:
    pickle.dump(test_field_data_r, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('test/test-g.pkl', 'wb') as f:
    pickle.dump(test_field_data_g, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('test/test-b.pkl', 'wb') as f:
    pickle.dump(test_field_data_b, f, protocol=pickle.HIGHEST_PROTOCOL)
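
# A minimal sketch of reading a saved stack back for analysis ('some_field_id'
# below is a placeholder for a real Field_Id from the shapefile):
#
#   with open('train/train.pkl', 'rb') as f:
#       train_field_data = pickle.load(f)
#   ndvi_stack = train_field_data[some_field_id]  # shape (H, W, T), one slice per datatake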