# Parse the SpaceNet dataset.
# Source: GitHub gist a-maumau/6d0a50ec15be89851e9fc65629111d07
# Created July 26, 2018 07:38.
import copy
import os

### geopandas must be imported earlier than osgeo ###
import geopandas as gpd
from osgeo import gdal, ogr, osr
import numpy as np
import scipy
# NOTE: scipy.misc.bytescale was deprecated in SciPy 1.0 and removed in
# SciPy 1.2; this import only succeeds on older SciPy releases.
from scipy.misc import bytescale
import osmnx
import cv2
import skimage
from skimage import exposure
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
# NVIDIA contrast-stretch code.
def retrieve_bands(ds, x_size, y_size, bands):
    """Read the requested raster bands of ``ds`` into one stacked array.

    Args:
        ds: Object exposing ``GetRasterBand(index)``; the returned band must
            expose ``ReadAsArray()`` (a GDAL dataset in this file).
        x_size: Length of the first axis of the result.
        y_size: Length of the second axis of the result.
        bands: Sequence of band indices handed to ``GetRasterBand``.

    Returns:
        A float64 array of shape ``(x_size, y_size, len(bands))`` in which
        ``[:, :, i]`` holds the data read for ``bands[i]``.
    """
    layers = np.zeros((x_size, y_size, len(bands)))
    for slot, band_index in enumerate(bands):
        # Each band array must fit an (x_size, y_size) slice of the stack.
        layers[..., slot] = ds.GetRasterBand(band_index).ReadAsArray()
    return layers
def contrast_stretch(np_img, p1_clip=2, p2_clip=90):
    """Percentile contrast stretch of every band to the 8-bit range.

    For each band, the values found at the ``p1_clip`` and ``p2_clip``
    percentiles are mapped to 0 and 255. Values outside that interval are
    clipped and values inside it are scaled linearly.

    The previous implementation passed the percentile values as the *output*
    range of ``rescale_intensity`` and then byte-scaled the result from its
    own min/max, which made the percentile arguments ineffective. It also
    relied on ``scipy.misc.bytescale``, which modern SciPy no longer ships.
    This version uses NumPy only.

    Args:
        np_img: 3-D array whose last axis indexes the bands.
        p1_clip: Lower percentile (0-100) mapped to 0.
        p2_clip: Upper percentile (0-100) mapped to 255.

    Returns:
        A ``uint8`` array with the same shape as ``np_img``. A band that is
        empty, constant between the two percentiles, or whose percentiles
        are not finite is returned as all zeros.
    """
    x, y, bands = np_img.shape
    stretched = np.zeros((x, y, bands), dtype=np.uint8)
    for b in range(bands):
        band = np_img[:, :, b].astype(np.float64)
        if band.size == 0:
            continue
        low, high = np.percentile(band, (p1_clip, p2_clip))
        span = high - low
        # ``not span > 0`` also rejects NaN, avoiding a division by zero or
        # an undefined float -> uint8 cast.
        if not span > 0:
            continue
        scaled = (np.clip(band, low, high) - low) / span * 255.0
        # Round half up, matching how bytescale quantised to bytes.
        stretched[:, :, b] = np.floor(scaled + 0.5).astype(np.uint8)
    return stretched
# The code below is from https://github.com/CosmiQ/apls/blob/master/src/apls_tools.py
def create_buffer_geopandas(geoJsonFileName,
                            bufferDistanceMeters=2,
                            bufferRoundness=1, projectToUTM=True):
    '''
    Create a buffer around the lines of the geojson.

    geoJsonFileName is read with geopandas. Unless the file has no features,
    it must provide a 'road_type' column, which is copied to 'type'.
    bufferDistanceMeters is the buffer distance; it is in meters when
    projectToUTM is True, because the buffer is then computed on the
    UTM-projected geometries.
    bufferRoundness is passed as the second positional argument of
    GeoDataFrame.buffer (presumably its resolution - confirm against the
    installed geopandas version).
    If projectToUTM is True the data is projected with osmnx.project_gdf
    before buffering and projected back to the input CRS afterwards.

    Return a geodataframe with the buffered geometries dissolved by 'class'.
    If the geojson has no features, the empty geodataframe is returned
    unchanged so that callers can test len(result) == 0.
    '''
    inGDF = gpd.read_file(geoJsonFileName)

    # Bail out before touching any column: an empty file may not have a
    # 'road_type' column. The empty frame has len() == 0, which is what
    # callers check; the former ([], []) tuple had len() == 2.
    if len(inGDF) == 0:
        return inGDF

    # set a few columns that we will need later
    inGDF['type'] = inGDF['road_type'].values
    inGDF['class'] = 'highway'
    inGDF['highway'] = 'highway'

    # Transform gdf Roadlines into UTM so that Buffer makes sense
    if projectToUTM:
        tmpGDF = osmnx.project_gdf(inGDF)
    else:
        tmpGDF = inGDF

    gdf_utm_buffer = tmpGDF

    # perform Buffer to produce polygons from Line Segments
    gdf_utm_buffer['geometry'] = tmpGDF.buffer(bufferDistanceMeters,
                                               bufferRoundness)

    gdf_utm_dissolve = gdf_utm_buffer.dissolve(by='class')
    gdf_utm_dissolve.crs = gdf_utm_buffer.crs

    if projectToUTM:
        # back to the coordinate system of the input file
        gdf_buffer = gdf_utm_dissolve.to_crs(inGDF.crs)
    else:
        gdf_buffer = gdf_utm_dissolve

    return gdf_buffer
def gdf_to_array(gdf, im_file, output_raster, burnValue=150):
    '''
    Turn geodataframe to array, save as image file with non-null pixels
    set to burnValue.

    gdf must provide a 'geometry' column of objects with a .wkt attribute.
    im_file is opened with GDAL and supplies the raster size, geotransform
    and projection of the output.
    output_raster is the path of the single-band, 8-bit raster to create.
    It is always written with the GTiff driver, whatever its extension.

    Returns None. Raises IOError if im_file cannot be opened.
    '''
    NoData_value = 0  # -9999

    gdata = gdal.Open(im_file)
    if gdata is None:
        raise IOError("Could not open image file: {}".format(im_file))

    # set target info
    target_ds = gdal.GetDriverByName('GTiff').Create(output_raster,
                                                     gdata.RasterXSize,
                                                     gdata.RasterYSize,
                                                     1, gdal.GDT_Byte)
    target_ds.SetGeoTransform(gdata.GetGeoTransform())

    # set raster info
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromWkt(gdata.GetProjectionRef())
    target_ds.SetProjection(raster_srs.ExportToWkt())

    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(NoData_value)

    # In-memory vector layer holding the polygons to rasterize.
    outdriver = ogr.GetDriverByName('MEMORY')
    outDataSource = outdriver.CreateDataSource('memData')
    outLayer = outDataSource.CreateLayer("states_extent", raster_srs,
                                         geom_type=ogr.wkbMultiPolygon)

    # burn
    burnField = "burn"
    idField = ogr.FieldDefn(burnField, ogr.OFTInteger)
    outLayer.CreateField(idField)
    featureDefn = outLayer.GetLayerDefn()
    for geomShape in gdf['geometry'].values:
        outFeature = ogr.Feature(featureDefn)
        outFeature.SetGeometry(ogr.CreateGeometryFromWkt(geomShape.wkt))
        outFeature.SetField(burnField, burnValue)
        outLayer.CreateFeature(outFeature)
        outFeature = None

    gdal.RasterizeLayer(target_ds, [1], outLayer, burn_values=[burnValue])

    # Make sure the raster is on disk before callers read it back, then drop
    # the GDAL/OGR handles (layer before its data source).
    target_ds.FlushCache()
    band = None
    target_ds = None
    outLayer = None
    outDataSource = None
    gdata = None
    return
def get_road_buffer(geoJson, im_vis_file, output_raster,
                    buffer_meters=2, burnValue=1, bufferRoundness=6,
                    plot_file='', figsize=(6,6), fontsize=6,
                    dpi=800, show_plot=False,
                    verbose=False):
    '''
    Get buffer around roads defined by geojson and image files.
    Calls create_buffer_geopandas() and gdf_to_array().
    Assumes in_vis_file is an 8-bit RGB file.

    The mask is written to output_raster and read back with OpenCV. If the
    geojson holds no roads, an all-zero mask with the size of im_vis_file is
    written instead.
    If plot_file is non-empty, a four-panel figure (road lines, image, mask,
    overlay) is saved there at the given dpi; figsize with equal sides gives
    a 2 x 2 grid, otherwise the panels are placed in one row. The figure is
    closed afterwards unless show_plot is True.
    verbose is accepted for compatibility and is currently unused.

    Returns (mask_gray, gdf_buffer): the output mask and the geodataframe.
    Raises IOError if im_vis_file cannot be read when an empty mask is
    needed.
    '''
    gdf_buffer = create_buffer_geopandas(geoJson,
                                         bufferDistanceMeters=buffer_meters,
                                         bufferRoundness=bufferRoundness,
                                         projectToUTM=True)

    # create_buffer_geopandas has signalled "no roads" with a tuple of empty
    # lists, whose len() is 2; treat that sentinel as empty as well.
    no_roads = isinstance(gdf_buffer, tuple) or len(gdf_buffer) == 0

    # create label image
    if no_roads:
        im_gray = cv2.imread(im_vis_file, 0)
        if im_gray is None:
            raise IOError("Could not read image file: {}".format(im_vis_file))
        # uint8 so that cv2.imwrite can store it, and so that it matches the
        # dtype of masks produced by the rasterization path.
        mask_gray = np.zeros(im_gray.shape, dtype=np.uint8)
        cv2.imwrite(output_raster, mask_gray)
    else:
        gdf_to_array(gdf_buffer, im_vis_file, output_raster,
                     burnValue=burnValue)
    # load mask
    mask_gray = cv2.imread(output_raster, 0)

    # make plots
    if plot_file:
        # plot all in a line
        if figsize[0] != figsize[1]:
            fig, (ax0, ax1, ax2, ax3) = plt.subplots(1, 4, figsize=figsize)
        # else, plot a 2 x 2 grid
        else:
            fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2,
                                                         figsize=figsize)

        # road lines; fall back to the mask if they cannot be drawn
        try:
            gdfRoadLines = gpd.read_file(geoJson)
            gdfRoadLines.plot(ax=ax0, marker='o', color='red')
        except Exception:
            ax0.imshow(mask_gray)
        ax0.axis('off')
        ax0.set_aspect('equal')
        ax0.set_title('Roads from GeoJson', fontsize=fontsize)

        # first show raw image
        im_vis = cv2.imread(im_vis_file, 1)
        img_mpl = cv2.cvtColor(im_vis, cv2.COLOR_BGR2RGB)
        ax1.imshow(img_mpl)
        ax1.axis('off')
        ax1.set_title('8-bit RGB Image', fontsize=fontsize)

        # plot mask
        ax2.imshow(mask_gray)
        ax2.axis('off')
        ax2.set_title('Roads Mask (' + str(np.round(buffer_meters))
                      + ' meter buffer)', fontsize=fontsize)

        # plot combined
        ax3.imshow(img_mpl)
        # overlay mask
        # set zeros to nan
        z = mask_gray.astype(float)
        z[z == 0] = np.nan
        # Work on a copy: set_over() on plt.cm.gray itself would change the
        # shared colormap for every later plot in the process.
        palette = copy.copy(plt.cm.gray)
        # values above vmax (i.e. the burned road pixels) are drawn in lime
        #palette.set_over('yellow', 0.9)
        palette.set_over('lime', 0.9)
        ax3.imshow(z, cmap=palette, alpha=0.66,
                   norm=matplotlib.colors.Normalize(vmin=0.5, vmax=0.9,
                                                    clip=False))
        ax3.set_title('8-bit RGB Image + Buffered Roads', fontsize=fontsize)
        ax3.axis('off')

        #plt.axes().set_aspect('equal', 'datalim')

        fig.tight_layout()
        fig.savefig(plot_file, dpi=dpi)
        if not show_plot:
            # close this figure explicitly rather than "the current one"
            plt.close(fig)

    return mask_gray, gdf_buffer
# Percentiles handed to contrast_stretch() when producing the JPEGs.
min_percent = 5
max_percent = 90

# Output directories for the road masks and the 8-bit JPEG images.
mask_output_dir = "masks"
jpg_output_dir = "jpgs"
# The same plot file is overwritten for every image.
plot_file = os.path.join('mask_plot.png')

# Please rewrite these to be compatible with your dataset directory.
tif_image_dir = "AOI_2_Vegas_Roads_Train/RGB-PanSharpen"
geojson_file_dir = "AOI_2_Vegas_Roads_Train/geojson/spacenetroads"

if not os.path.exists(mask_output_dir):
    os.makedirs(mask_output_dir)

if not os.path.exists(jpg_output_dir):
    os.makedirs(jpg_output_dir)

image_list = os.listdir(tif_image_dir)

for name in image_list:
    # Built outside the try block so the skip message can always use it.
    image_name = os.path.join(tif_image_dir, name)
    try:
        print(name)

        # Please rewrite this to be compatible with your dataset.
        geojson_file = os.path.join(
            geojson_file_dir,
            name.replace(".tif", ".geojson").replace(
                "RGB-PanSharpen_AOI_2_Vegas_img",
                "spacenetroads_AOI_2_Vegas_img"))
        jpg_name = os.path.join(jpg_output_dir, name.replace(".tif", ".jpg"))
        mask_raster = os.path.join(mask_output_dir,
                                   name.replace(".tif", ".png"))

        # get_road_buffer() saves the mask (and plot) itself, so mask and
        # gdf_buffer are not actually needed here.
        mask, gdf_buffer = get_road_buffer(geojson_file, image_name,
                                           mask_raster,
                                           #buffer_meters=2,
                                           #burnValue=1,
                                           bufferRoundness=6,
                                           plot_file=plot_file,
                                           figsize=(6,6), #(13,4),
                                           fontsize=8,
                                           dpi=200, show_plot=False,
                                           verbose=False)

        ds = gdal.Open(image_name)
        # Stack bands 1, 2 and 3 and move the band axis last:
        # (band, row, col) -> (row, col, band). The third channel used to
        # read band 2 a second time instead of band 3.
        channel = np.array([ds.GetRasterBand(band_index).ReadAsArray()
                            for band_index in (1, 2, 3)]).transpose(1, 2, 0)
        band = contrast_stretch(channel, min_percent, max_percent)

        # little bit change (disabled)
        #band[:,:,0] = band[:,:,0]
        #band[band[:,:,1]>1] = band[band[:,:,1]>1]-2
        #band[band[:,:,2]>9] = band[band[:,:,2]>9]-10

        img = Image.fromarray(np.uint8(band))
        img.save(jpg_name, quality=100)

    except Exception as exc:
        # Best effort: keep going with the next image, but say why this one
        # failed. KeyboardInterrupt is no longer swallowed.
        print("skip {}".format(image_name))
        print("  reason: {!r}".format(exc))
# (GitHub gist page footer: sign up or sign in on GitHub to comment on this gist.)