gdal api笔记

​ 如果你的gdal版本是3.2.2及以上,那么只能:

from osgeo import gdal

WBtMsU.png


# Open the raster whose path is held in `data`.
ds = gdal.Open(data)
# Raster size along the Y axis (number of rows).
rows = ds.RasterYSize
# Raster size along the X axis (number of columns).
cols = ds.RasterXSize
# Number of bands in the dataset.
bandnum = ds.RasterCount
# Georeferencing coefficients of the dataset (indexed 0..5 in the notes below).
transform = ds.GetGeoTransform()

ds是一个数据集(Dataset)对象;rows是影像的行数,即Y轴方向的长度;对应地,cols是影像的列数,即X轴方向的长度;bandnum代表影像的波段数。

transform是一个包含6个元素的tuple,存储着栅格数据集的地理坐标信息(仿射变换参数)。

#transform[0] /* top left x 左上角x坐标(经度)*/
#transform[1] /* w--e pixel resolution 东西方向上的像素分辨率*/
#transform[2] /* rotation, 0 if image is "north up" 旋转参数,如果影像北边朝上则为0*/
#transform[3] /* top left y 左上角y坐标(纬度)*/
#transform[4] /* rotation, 0 if image is "north up" 旋转参数,如果影像北边朝上则为0*/
#transform[5] /* n-s pixel resolution 南北方向上的像素分辨率(影像北边朝上时为负值)*/

from osgeo import osr

# Read the six geotransform coefficients of the dataset opened above as `ds`.
transform = ds.GetGeoTransform()
originX = transform[0]
originY = transform[3]
pixelWidth = transform[1]
pixelHeight = transform[5]

# Inspect the coordinate system, option 1: the projection string of `ds`.
wkt = ds.GetProjection()

# Inspect the coordinate system, option 2: parse that string into a
# SpatialReference and print the value stored at index 1 of its 'AUTHORITY'
# node (presumably the EPSG code -- TODO confirm for your data).
proj = osr.SpatialReference(wkt=wkt)
space = proj.GetAttrValue('AUTHORITY', 1)
print(space)

矩阵操作
# Builds a new array in which every element of `data` equal to -9999 is
# replaced by 0; the result is not assigned here, so `data` itself is unchanged.
np.where(data==-9999,0,data)
python读取图像并显示
# -*- coding: utf-8 -*-
import numpy as np
import sys
from osgeo import gdal
from osgeo.gdalconst import GA_ReadOnly
import matplotlib.pyplot as plt

def disp(infile, bandnumber):
    """Display one band of a raster image as a gray-scale picture.

    Only the requested band is read, stretched linearly to the range 0..1
    and shown with matplotlib.

    Args:
        infile: Path of the raster file to open with GDAL.
        bandnumber: 1-based index of the band to display.

    Raises:
        RuntimeError: If GDAL cannot open ``infile``.
        ValueError: If ``bandnumber`` is not in ``1..RasterCount``.
    """
    gdal.AllRegister()

    # Open the image read-only. gdal.Open returns None on failure unless
    # GDAL exceptions have been enabled, so check explicitly.
    inDataset = gdal.Open(infile, GA_ReadOnly)
    if inDataset is None:
        raise RuntimeError('cannot open raster file: %s' % infile)

    bands = inDataset.RasterCount
    if not 1 <= bandnumber <= bands:
        inDataset = None
        raise ValueError(
            'bandnumber must be between 1 and %d, got %d' % (bands, bandnumber))

    # Read just the band that will be shown instead of every band.
    band = inDataset.GetRasterBand(bandnumber).ReadAsArray().astype(np.float64)
    # Close the dataset.
    inDataset = None

    # Linear stretch to 0..1; a constant band would divide by zero, so it is
    # shown as all zeros instead.
    mn = np.amin(band)
    mx = np.amax(band)
    if mx > mn:
        stretched = (band - mn) / (mx - mn)
    else:
        stretched = np.zeros_like(band)
    plt.imshow(stretched, cmap='gray')
    plt.show()
    
if __name__ == '__main__':
    # Expected command line: python show_image.py <image path> <band number>
    if len(sys.argv) < 3:
        sys.exit('usage: python show_image.py <image path> <band number>')
    infile = sys.argv[1]
    bandnumber = int(sys.argv[2])
    disp(infile, bandnumber)

使用方法:

1.把以上代码保存为show_image.py

2.在脚本所在目录打开cmd命令行,运行

python show_image.py xxx.tif 4

xxx.tif 是待读取影像的绝对路径, 4 代表待显示的波段(波段号从1开始计数)

示例如下:

avatar


多波段

如果自己做的图包含多个波段(往往大于4个),Opencv或PIL就不太顶用了,这时候GDAL就派上用场了
例如我有一个十波段图像,用此函数读取后为numpy数组类,shape为[h,w,10]

from osgeo import gdal
import numpy as np

def load_img(path):
    """Read every band of a raster into a channel-last numpy array.

    Args:
        path: Path of the raster file to open with GDAL.

    Returns:
        Array of shape ``(height, width, bands)``. A single-band raster
        yields shape ``(height, width, 1)``.

    Raises:
        RuntimeError: If GDAL cannot open ``path``.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        raise RuntimeError('cannot open raster file: %s' % path)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    if im_data.ndim == 2:
        # A single-band read comes back 2-D; add the band axis so the
        # transpose below works for it too.
        im_data = im_data[np.newaxis, :, :]
    # Move the band axis to the end (channel-last layout).
    im_data = im_data.transpose((1, 2, 0))
    return im_data

读取全部波段

from osgeo     import gdal
# Open the dataset whose path is held in `file`.
ds = gdal.Open(file)
# ReadAsArray is called without a window here; per the heading above this
# reads all bands of the dataset.
all_band_data = ds.ReadAsArray()

保存影像

GDT_Byte、GDT_UInt16、GDT_Float32

from osgeo import gdal
import numpy as np
import os
def _nmbwi(blue, green, red, nir):
    """Return (3g - b + 2r - 5n) / (3g + b + 2r + 5n) as float32; 0 where the denominator is 0."""
    # Convert first: integer band arrays could otherwise wrap around in the
    # subtraction and would be divided without a float result being intended.
    blue, green, red, nir = (
        np.asarray(band, dtype=np.float32) for band in (blue, green, red, nir))
    numerator = 3 * green - blue + 2 * red - 5 * nir
    denominator = 3 * green + blue + 2 * red + 5 * nir
    index = np.zeros_like(numerator)
    np.divide(numerator, denominator, out=index, where=denominator != 0)
    return index


def waterExtractOHS(file):
    """Compute the water index of a multi-band image and save it next to the input.

    Bands 2, 7, 14 and 28 of the input are read as blue, green, red and
    near-infrared. The single-band Float32 result is written with the
    input's driver to ``<input name>_nmbwi.tif`` and carries the input's
    geotransform and projection.

    Args:
        file: Path of the multi-band input raster.

    Returns:
        Path of the written index raster.

    Raises:
        RuntimeError: If the input cannot be opened or the output cannot be
            created.
    """
    inDataset = gdal.Open(file)
    if inDataset is None:
        raise RuntimeError('cannot open raster file: %s' % file)
    driver = inDataset.GetDriver()
    cols = inDataset.RasterXSize
    rows = inDataset.RasterYSize

    print('读取波段')
    blue = inDataset.GetRasterBand(2).ReadAsArray(0, 0, cols, rows)
    green = inDataset.GetRasterBand(7).ReadAsArray(0, 0, cols, rows)
    red = inDataset.GetRasterBand(14).ReadAsArray(0, 0, cols, rows)
    nir = inDataset.GetRasterBand(28).ReadAsArray(0, 0, cols, rows)

    print('计算指数')
    nmbwi = _nmbwi(blue, green, red, nir)

    outname = os.path.splitext(file)[0] + '_nmbwi.tif'
    dataset = driver.Create(outname, cols, rows, 1, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError('cannot create output raster: %s' % outname)
    # Carry the georeferencing over so the result overlays the input.
    dataset.SetGeoTransform(inDataset.GetGeoTransform())
    dataset.SetProjection(inDataset.GetProjection())
    outBand = dataset.GetRasterBand(1)
    outBand.WriteArray(nmbwi, 0, 0)
    outBand.FlushCache()

    # Drop the references so GDAL writes and closes the files.
    outBand = None
    dataset = None
    inDataset = None
    return outname

if __name__ == '__main__':
    # Path of the input raster handed to waterExtractOHS.
    file = r'BandStack_6s_orth.tif'
    waterExtractOHS(file)

将两景配准好的影像重采样到同一分辨率

def resample(sourceDataset, dstDataset, outname, interp):
    """Resample ``sourceDataset`` onto the grid of ``dstDataset`` as a GeoTIFF.

    The output takes its size, geotransform and projection from
    ``dstDataset`` and its band count from ``sourceDataset``; pixels are
    stored as Float32. If ``outname`` already exists nothing is written and
    the existing file is reused.

    Args:
        sourceDataset (osgeo.gdal.Dataset): Input dataset to be resampled.
        dstDataset (osgeo.gdal.Dataset): Dataset that defines the target grid.
        outname (str): Name of the resampled output GeoTIFF.
        interp (int): GDAL resampling method, e.g. ``gdal.GRA_Cubic``.

    Returns:
        str: ``outname``.

    Raises:
        RuntimeError: If the output GeoTIFF cannot be created.

    Example:
        resampledMultispectralGeotiffFilename = os.path.splitext(nightlightfile)[0] + 'nightlight.tif'
        resample(nightlightset, ohsset, resampledMultispectralGeotiffFilename,
                 gdal.GRA_Bilinear)
    """
    print('resample image')
    # Projection and band count come from the source dataset ...
    srcProjection = sourceDataset.GetProjection()
    srcNumRasters = sourceDataset.RasterCount
    # ... while the output grid comes from the destination dataset.
    dstProjection = dstDataset.GetProjection()
    dstGeotransform = dstDataset.GetGeoTransform()
    nrows = dstDataset.RasterYSize
    ncols = dstDataset.RasterXSize

    # Only create the output when it does not exist yet; an existing file is
    # left untouched.
    if not os.path.isfile(outname):
        dst_ds = gdal.GetDriverByName('GTiff').Create(
            outname, ncols, nrows, srcNumRasters, gdal.GDT_Float32)
        if dst_ds is None:
            raise RuntimeError('cannot create output raster: %s' % outname)
        dst_ds.SetGeoTransform(dstGeotransform)
        dst_ds.SetProjection(dstProjection)
        gdal.ReprojectImage(sourceDataset, dst_ds, srcProjection, dstProjection, interp)
        # Drop the reference so GDAL writes and closes the file.
        dst_ds = None
    print('完成resample')
    return outname

数据类型
GDAL中的GDALDataType是一个枚举型,其中的值为:

GDT_Unknown : 未知数据类型
GDT_Byte : 8bit无符号整型 (C++中对应unsigned char)
GDT_UInt16 : 16bit无符号整型 (C++中对应 unsigned short)
GDT_Int16 : 16bit有符号整型 (C++中对应 short 或 short int)
GDT_UInt32 : 32bit无符号整型 (C++中对应 unsigned int)
GDT_Int32 : 32bit有符号整型 (C++中对应 int)
GDT_Float32 : 32bit 浮点型 (C++中对应float)
GDT_Float64 : 64bit 浮点型 (C++中对应double)
GDT_CInt16 : 复数型,实部和虚部各为16bit整型
GDT_CInt32 : 复数型,实部和虚部各为32bit整型
GDT_CFloat32 : 复数型,实部和虚部各为32bit浮点型
GDT_CFloat64 : 复数型,实部和虚部各为64bit浮点型

shp相关

分块

ds = gdal.Open(file)
# NOTE(review): DatasetInfo is not defined anywhere in these notes and `info`
# is not used below -- confirm where it is supposed to come from.
info = DatasetInfo(file)
rows = ds.RasterYSize  # image height (number of rows)
cols = ds.RasterXSize  # image width (number of columns)
# Edge length of one square block, in pixels.
nBlockSize = 500
# Walk the image block by block: i is the first row and j the first column
# of the current block.
for i in range(0, rows, nBlockSize):
    # The last block of a column may be shorter than nBlockSize.
    nYBK = min(nBlockSize, rows - i)
    for j in range(0, cols, nBlockSize):
        # The last block of a row may be narrower than nBlockSize.
        nXBK = min(nBlockSize, cols - j)
        # The current window starts at column j, row i and is nXBK pixels
        # wide and nYBK pixels high; process it here.

读取金字塔

from osgeo import gdal
# NOTE(review): `ovrfile` is assigned but never used, and `file` (passed to
# OpenEx below) is not defined in this snippet -- confirm which path is meant.
ovrfile = r'xxx.ovr'
# Open with the OVERVIEW_LEVEL=0 open option (presumably this selects the
# first pyramid level -- TODO confirm against the GDAL documentation).
dsOVR = gdal.OpenEx(file, open_options=['OVERVIEW_LEVEL=0'])
# Read bands 1 to 3 of the opened dataset as arrays.
OVRdata1 = dsOVR.GetRasterBand(1).ReadAsArray()
OVRdata2 = dsOVR.GetRasterBand(2).ReadAsArray()
OVRdata3 = dsOVR.GetRasterBand(3).ReadAsArray()

读取shp范围栅格数据

import fiona
import rasterio
from rasterio.mask import mask
import matplotlib.pyplot as plt
import numpy as np
import time


def read_raster_shp(shpfile, tiffile):
    """Read the raster data covered by the shapes of a shapefile.

    Args:
        shpfile: Path of the shapefile whose geometries are used as the mask.
        tiffile: Path of the raster file to read.

    Returns:
        The image array produced by ``rasterio.mask.mask`` for those
        geometries (called with ``invert=False``).
    """
    shapes = []
    with fiona.open(shpfile, "r") as layer:
        for record in layer:
            shapes.append(record["geometry"])

    with rasterio.open(tiffile) as raster:
        # mask() also returns a transform, which is not needed here.
        masked_pixels, _ = mask(raster, shapes, invert=False)
    return masked_pixels
    
    
if __name__ == '__main__':
    a = time.time()

    tiffile = r"_clip.tif"
    shpfile = r'New_Shapefile(2).shp'

    out_image = read_raster_shp(shpfile, tiffile)
    # out_image is indexed as [band, row, column] below.
    bands, height, width = np.shape(out_image)
    # Build a (row, column, 3) picture from the first three bands, each
    # scaled by its own maximum.
    img_arr = np.zeros(shape=(height, width, 3))
    for i in range(3):
        band_max = np.nanmax(out_image[i, :, :])
        # A band whose maximum is 0 stays all zeros rather than dividing by 0.
        if band_max != 0:
            img_arr[:, :, i] = out_image[i, :, :] / band_max
    plt.imshow(img_arr)
    plt.show()

    # Report the elapsed time in seconds.
    b = time.time()
    print(b - a)

单GCP点粗校正

针对RPC校正后仍出现严重的几何错位问题

获取源影像的空间分辨率、投影坐标
提前获知源影像单个GCP点的经度、纬度、高度、列数、行数
具体代码如下:



from osgeo import gdal, osr
import os, shutil

if __name__ == "__main__":
    # NOTE(review): google_tif is not used anywhere in this snippet.
    google_tif = r'D:\xxx\OVS\out.tif'
    tif = r'D:\xxx\OVS\L1\VDM1_20210927224960_0008_L1_MSS_CMOS2\VDM1_RPC.tif'
    OutputFilePath = r'D:\xxx\OVS'

    # The single known ground control point: its longitude, latitude and
    # height, and the column / row at which it sits in the source image.
    gcp_lon, gcp_lat, gcp_height = 132.49583, 34.23777777, 0
    gcp_col, gcp_row = 2940, 8746

    # Read the pixel size of the source image, then release the dataset.
    ds = gdal.Open(tif)
    geotrans = ds.GetGeoTransform()  # affine geotransform coefficients
    pixelWidth = geotrans[1]  # pixel spacing along the x axis
    ds = None

    # Work on a copy in the output folder so the original stays untouched.
    new_file = os.path.join(OutputFilePath, os.path.basename(tif))
    shutil.copyfile(tif, new_file)
    print('原始待校正影像复制到输出路径:', new_file)

    dataset = gdal.Open(new_file, gdal.GA_Update)

    sr = osr.SpatialReference()
    sr.SetWellKnownGeogCS('EPSG:4326')

    # Derive three more control points from the known one: the image's
    # top-left corner and the two points sharing a row or column with the GCP.
    # pixelWidth is used for both axes, as in the original notes
    # (NOTE(review): this assumes square pixels -- confirm).
    left_lon = gcp_lon - pixelWidth * gcp_col
    top_lat = gcp_lat + pixelWidth * gcp_row
    gcp_list = [gdal.GCP(gcp_lon, gcp_lat, gcp_height, gcp_col, gcp_row),
                gdal.GCP(left_lon, top_lat, gcp_height, 0, 0),
                gdal.GCP(gcp_lon, top_lat, gcp_height, gcp_col, 0),
                gdal.GCP(left_lon, gcp_lat, gcp_height, 0, gcp_row)]
    dataset.SetGCPs(gcp_list, sr.ExportToWkt())

    out_basename = os.path.splitext(os.path.basename(tif))[0] + '_sift_gcp.tiff'
    out_filename = os.path.join(OutputFilePath, out_basename)

    # Warp the copy with a thin-plate-spline transform built from the GCPs.
    gdal.Warp(out_filename, dataset, tps=True, xRes=pixelWidth,
              yRes=pixelWidth, srcNodata=-9999,
              resampleAlg=gdal.GRIORA_Bilinear, outputType=gdal.GDT_Float32)
    # Drop the reference so the updated copy is flushed and closed.
    dataset = None

后续需写GUI,实现获知源影像单个GCP点的经度、纬度、高度、列数、行数

获取影像无效值

from osgeo import gdal
# Path of the raster to inspect (fill in).
file = ''
ds = gdal.Open(file)
# Nodata value of band 1.
nodata = ds.GetRasterBand(1).GetNoDataValue()

从DEM获取指定经纬度的高度

from osgeo import gdal
import numpy as np
from registration import  ImageRegister
def fromDemFileGetHeight(dem, lon, lat):
    """Return the value of the DEM pixel that contains the point (lon, lat).

    The pixel is located from the DEM's origin and pixel sizes only; the
    rotation terms of the geotransform are not used.

    Args:
        dem: Path of the DEM raster.
        lon: X coordinate (longitude) of the point, in the DEM's coordinates.
        lat: Y coordinate (latitude) of the point, in the DEM's coordinates.

    Returns:
        Value of band 1 at that pixel (``np.mean`` over the 1x1 window).

    Raises:
        RuntimeError: If the DEM file cannot be opened.
        ValueError: If the point lies outside the DEM.
    """
    DEMIDataSet = gdal.Open(dem)
    if DEMIDataSet is None:
        raise RuntimeError('Missing DEM file: %s' % dem)
    DEMBand = DEMIDataSet.GetRasterBand(1)
    geotransform = DEMIDataSet.GetGeoTransform()
    # DEM resolution: pixel size along x and along y.
    pixelWidth = geotransform[1]
    pixelHeight = geotransform[5]
    # DEM origin (top-left corner): X is longitude, Y is latitude.
    originX = geotransform[0]
    originY = geotransform[3]

    # Column comes from the x offset and the x pixel size, row from the y
    # offset and the y pixel size. floor (not int truncation) keeps points
    # just outside the top/left edge from being mapped onto pixel 0.
    xoffset = int(np.floor((lon - originX) / pixelWidth))
    yoffset = int(np.floor((lat - originY) / pixelHeight))
    if not (0 <= xoffset < DEMIDataSet.RasterXSize
            and 0 <= yoffset < DEMIDataSet.RasterYSize):
        raise ValueError(
            'point (%s, %s) lies outside the DEM extent' % (lon, lat))

    DEMRasterData = DEMBand.ReadAsArray(xoffset, yoffset, 1, 1)
    DEMRasterData = np.mean(DEMRasterData)
    return DEMRasterData
from osgeo import gdal
import numpy as np
from registration import  ImageRegister
def fromDemFileGetHeight(dem, lon, lat):
    """Return the value of the DEM pixel that contains the point (lon, lat).

    The pixel is located from the DEM's origin and pixel sizes only; the
    rotation terms of the geotransform are not used.

    Args:
        dem: Path of the DEM raster.
        lon: X coordinate (longitude) of the point, in the DEM's coordinates.
        lat: Y coordinate (latitude) of the point, in the DEM's coordinates.

    Returns:
        Value of band 1 at that pixel (``np.mean`` over the 1x1 window).

    Raises:
        RuntimeError: If the DEM file cannot be opened.
        ValueError: If the point lies outside the DEM.
    """
    DEMIDataSet = gdal.Open(dem)
    if DEMIDataSet is None:
        raise RuntimeError('Missing DEM file: %s' % dem)
    DEMBand = DEMIDataSet.GetRasterBand(1)
    geotransform = DEMIDataSet.GetGeoTransform()
    # DEM resolution: pixel size along x and along y.
    pixelWidth = geotransform[1]
    pixelHeight = geotransform[5]
    # DEM origin (top-left corner): X is longitude, Y is latitude.
    originX = geotransform[0]
    originY = geotransform[3]

    # Column comes from the x offset and the x pixel size, row from the y
    # offset and the y pixel size. floor (not int truncation) keeps points
    # just outside the top/left edge from being mapped onto pixel 0.
    xoffset = int(np.floor((lon - originX) / pixelWidth))
    yoffset = int(np.floor((lat - originY) / pixelHeight))
    if not (0 <= xoffset < DEMIDataSet.RasterXSize
            and 0 <= yoffset < DEMIDataSet.RasterYSize):
        raise ValueError(
            'point (%s, %s) lies outside the DEM extent' % (lon, lat))

    DEMRasterData = DEMBand.ReadAsArray(xoffset, yoffset, 1, 1)
    DEMRasterData = np.mean(DEMRasterData)
    return DEMRasterData

普通图片

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from PIL import Image
import numpy as np

I = Image.open(r'C:\Users\Administrator\Pictures\1.jpg')
I.show()
I.save('./save.png')  # save as a new picture
I_array = np.array(I)  # convert to a numpy array
if I_array.ndim == 2:
    # A picture without a band axis: add one so the unpacking below works.
    I_array = I_array[:, :, np.newaxis]
# The array is indexed as [row, column, band].
rows, cols, bands = I_array.shape
with open(r'C:\Users\Administrator\Pictures\save.txt', 'w',encoding="utf8") as f:
    for band in range(bands):
        f.write('我是第%d波段的数据'%band)
        for row in range(rows):
            for col in range(cols):
                value = I_array[row, col, band]
                print("%d波段的%d行%d列像素是:%d" % (band, row, col, value))
                f.write(str(value))
                f.write(' ')
            # One text line per image row.
            f.write('\n')
        # Two blank lines between bands.
        f.write('\n')
        f.write('\n')

为栅格数据添加地面控制点

import shutil
from osgeo import gdal, osr

orig_fn = r''
# Path of the working copy (fill in). The dataset is opened for update, so
# the work is done on a copy and the original is kept as a backup.
fn = r''
shutil.copy(orig_fn, fn)

sr = osr.SpatialReference()
sr.SetWellKnownGeogCS('WGS84')  # WGS84 datum
gcps = [gdal.GCP(-111.93, 41.74, 0, 1078, 648),
        gdal.GCP(-111.90, 41.75, 0, 3531, 295)]

# Set to True to store a geotransform derived from the ground control points
# (first-order transform) instead of the control points themselves.
use_first_order_transform = False

ds = gdal.Open(fn, gdal.GA_Update)
if use_first_order_transform:
    # Make sure both the projection and the geotransform are set on the dataset.
    ds.SetProjection(sr.ExportToWkt())
    # Build the geotransform from the ground control points and store it.
    ds.SetGeoTransform(gdal.GCPsToGeoTransform(gcps))
else:
    # Attach the ground control points to the dataset.
    ds.SetGCPs(gcps, sr.ExportToWkt())
# Drop the reference so the changes are written and the file is closed.
ds = None

几何校正

import numpy as np
from osgeo import gdal

def get_indices(source_ds, target_width, target_height):
    """Return the source-pixel coordinates of the centres of the target cells.

    Args:
        source_ds: Dataset providing ``GetGeoTransform()``, ``RasterXSize``
            and ``RasterYSize``.
        target_width: Cell width of the target grid, in the units of the
            source geotransform.
        target_height: Cell height of the target grid, same units and sign
            convention as the source geotransform.

    Returns:
        The result of ``np.meshgrid(target_x, target_y)``: two 2-D arrays
        with the fractional column and row offsets, in source pixels, of
        every target cell centre.
    """
    source_geotransform = source_ds.GetGeoTransform()
    source_width = source_geotransform[1]
    source_height = source_geotransform[5]
    # Size of one target cell measured in source pixels.
    dx = target_width / source_width
    dy = target_height / source_height
    # Start half a target cell in, so the coordinates are cell centres.
    target_x = np.arange(dx / 2, source_ds.RasterXSize, dx)
    target_y = np.arange(dy / 2, source_ds.RasterYSize, dy)
    return np.meshgrid(target_x, target_y)

def bilinear(in_data, x, y):
    """Bilinearly interpolate a 2-D array at fractional pixel coordinates.

    Coordinates are pixel offsets in which the centre of pixel ``[0, 0]`` is
    at ``(0.5, 0.5)``. Samples that need a neighbour beyond the array edge
    use the nearest edge pixel instead. ``x`` and ``y`` are not modified.

    Args:
        in_data: 2-D array indexed as ``[row, column]``.
        x: Column coordinates of the sample points (scalar or array).
        y: Row coordinates of the sample points, same shape as ``x``.

    Returns:
        Interpolated values, with the shape of ``x``.
    """
    # Work on shifted copies so the caller's arrays are left untouched.
    x = np.asarray(x, dtype=float) - 0.5
    y = np.asarray(y, dtype=float) - 0.5
    x0 = np.floor(x).astype(int)
    x1 = x0 + 1
    y0 = np.floor(y).astype(int)
    y1 = y0 + 1

    # Clamp the neighbour indices to the array. The weights below still use
    # the unclamped positions, so they keep summing to 1 at the edges.
    last_row = in_data.shape[0] - 1
    last_col = in_data.shape[1] - 1
    col0 = np.clip(x0, 0, last_col)
    col1 = np.clip(x1, 0, last_col)
    row0 = np.clip(y0, 0, last_row)
    row1 = np.clip(y1, 0, last_row)

    ul = in_data[row0, col0] * (y1 - y) * (x1 - x)
    ur = in_data[row0, col1] * (y1 - y) * (x - x0)
    ll = in_data[row1, col0] * (y - y0) * (x1 - x)
    lr = in_data[row1, col1] * (y - y0) * (x - x0)

    return ul + ur + ll + lr

if __name__ == "__main__":
    in_fn = r''
    out_fn = r''
    # Target cell size as (width, height).
    cell_size = (0.02, -0.02)

    in_ds = gdal.Open(in_fn)
    x, y = get_indices(in_ds, *cell_size)
    outdata = bilinear(in_ds.ReadAsArray(), x, y)

    driver = gdal.GetDriverByName('GTiff')
    rows, cols = outdata.shape
    out_ds = driver.Create(out_fn, cols, rows, 1, gdal.GDT_Int32)
    # Keep the input's coordinate system on the output.
    out_ds.SetProjection(in_ds.GetProjection())

    # Same origin as the input, but with the new cell size.
    gt = list(in_ds.GetGeoTransform())
    gt[1] = cell_size[0]
    gt[5] = cell_size[1]
    out_ds.SetGeoTransform(gt)

    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(outdata)
    out_band.FlushCache()
    out_band.ComputeStatistics(False)

    # Drop the references so GDAL writes and closes the files.
    out_band = None
    out_ds = None
    in_ds = None