遥感大影像位深转8bit的方法及实现

位深

位深(Bit Depth)是指数字图像中每个像素用于记录颜色信息的位数。位深越高,可用于表示颜色的信息就越丰富,图像的颜色就越真实且细腻。然而,位深越高,图像文件的大小也就越大。

在数字图像中,常见的位深有8位、16位、24位和32位等。下面我们来看一下8位和16位位深的区别:

  1. 8位位深:每个像素用8位(即1字节)来记录颜色信息。在灰度图像中,8位位深可以表示256(即$2^8)种不同的灰度等级。在彩色图像中,通常每种颜色(红、绿、蓝)都用8位来记录,因此可以表示256^3(即16777216)种不同的颜色。
  2. 16位位深:每个像素用16位(即2字节)来记录颜色信息。在灰度图像中,16位位深可以表示65536(即$2^{16})种不同的灰度等级。在彩色图像中,通常每种颜色(红、绿、蓝)都用16位来记录,因此可以表示65536^3种不同的颜色。

从这里可以看出,16位位深的图像比8位位深的图像有更丰富的颜色信息,能更精细地表示图像的颜色变化。然而,16位位深的图像文件的大小也会比8位位深的图像文件大一倍。

遥感影像大多数未拉伸前,一般为10\14\16bit。

为什么要做

在深度学习中,数据的表示范围和精度对训练的影响是一个重要的考虑因素。虽然16位(即2字节)的数据可以提供更大的范围和更高的精度,但这也意味着需要更多的计算资源(如内存和计算能力)。在许多情况下,8位(即1字节)的数据已经足够满足深度学习的需求,并且可以显著减少所需的计算资源。

以下是将16位数据转换为8位数据的一些主要原因:

  1. 内存和存储空间:8位数据只需要16位数据一半的内存和存储空间。这对于处理大量数据(如图像或视频)的深度学习任务来说是非常重要的。
  2. 计算速度:处理8位数据的计算速度通常会比处理16位数据更快。这是因为计算机的内存和处理器都是按字节(8位)组织的,处理8位数据可以更有效地利用这些硬件资源。
  3. 兼容性:许多深度学习框架和算法都是针对8位数据进行优化的。使用8位数据可以确保与这些框架和算法的兼容性。
  4. 精度需求:虽然16位数据可以提供更高的精度,但在许多深度学习任务中,这种高精度并不是必需的。实际上,一些研究表明,使用较低精度的数据对深度学习模型的性能影响很小。

总的来说,将16位数据转换为8位数据是一种权衡。这种转换可能会损失一些精度,但可以获得更高的计算效率和更好的兼容性。

我做16位降8位的原因大多数源自于深度学习训练模型需要8bit的需求。

有什么困难

常规的16bit转8bit,是利用拉伸的方式进行。但是大影像进行降位处理会出现内存不足的问题。就是不能把数据一次性储存到内存中。一般会出现以下错误提示:

numpy.core._exceptions.MemoryError: Unable to allocate 4.14 GiB for an array with shape (4, 29648, 38489) and data type uint8

这个错误信息表示程序试图在内存中创建一个非常大的数组,但是可用的内存不足,因此无法完成这个操作。

怎么解决

利用分块的思路,按照规则依次把小分块进行拉伸,然后保存到新的tif中。

其中首先计算了整个图像的最大值和最小值,然后在处理每个分块时使用这些参数。

这样做的好处在于,分块处理拉伸问题,但全局使用同样的参数,使得最后结果无明显的分块痕迹。

具体代码

#!/usr/bin/env python
# -*- coding: utf-8 -*- 
# @Time : 2024/2/27 21:40 
# @File : 16to8.py 

from osgeo import gdal
import numpy as np
import time
import os
from tqdm import tqdm

class ImageCompressor:
    def __init__(self, path):
        self.path = path

    def read_img(self, input_file):
        '''
        这个方法用于打开输入的图像文件,并返回一个包含图像数据和图像的行数、列数、波段数的元组。
        :param input_file: 输入的图像文件
        :return: 包含图像数据和图像的行数、列数、波段数的元组
        '''
        in_ds = gdal.Open(input_file, gdal.GA_ReadOnly)  # 仅打开文件
        rows = in_ds.RasterYSize  # 获取数据高度
        cols = in_ds.RasterXSize  # 获取数据宽度
        bands = in_ds.RasterCount  # 获取数据波段数
        datatype = in_ds.GetRasterBand(1).DataType
        # GDT_Byte = 1, GDT_UInt16 = 2, GDT_UInt32 = 4, GDT_Int32 = 5, GDT_Float32 = 6
        data_type_name = {1: 'GDT_Byte', 2: 'GDT_UInt16',3:'GDT_Int16', 4: 'GDT_UInt32', 5: 'GDT_Int32', 6: 'GDT_Float32'}
        print("数据类型:", data_type_name[datatype])
        array_data = in_ds.ReadAsArray()  # 将数据写成数组,读取全部数据,numpy数组
        del in_ds
        return array_data, rows, cols, bands

    def write_img(self, read_path, img_array):
        '''
        这个方法用于将图像数据写入到一个新的文件中。
        :param read_path: 读取的文件路径
        :param img_array: 图像数据
        :return: None
        '''
        read_pre_dataset = gdal.Open(read_path)
        img_transf = read_pre_dataset.GetGeoTransform()  # 仿射矩阵
        img_proj = read_pre_dataset.GetProjection()  # 地图投影信息
        print("read shape:", img_array.shape, img_array.dtype.name)
        if 'uint8' in img_array.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in img_array.dtype.name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        if len(img_array.shape) == 3:
            img_bands, im_height, im_width = img_array.shape
        else:
            img_bands, (im_height, im_width) = 1, img_array.shape
        filename = read_path[:-4] + '_unit8' + ".tif"
        driver = gdal.GetDriverByName("GTiff")  # 创建文件驱动
        dataset = driver.Create(filename, im_width, im_height, img_bands, datatype)
        dataset.SetGeoTransform(img_transf)  # 写入仿射变换参数
        dataset.SetProjection(img_proj)  # 写入投影
        if img_bands == 1:
            dataset.GetRasterBand(1).WriteArray(img_array)
        else:
            for i in range(img_bands):
                dataset.GetRasterBand(i + 1).WriteArray(img_array[i])

    def compress(self, origin_16, chunk_size=5000):
        '''
        这个方法是图像压缩的主要部分。它首先读取原图像的数据,然后计算全局的最大值和最小值,然后创建一个新的文件,
        并在每个分块上应用线性拉伸,最后将处理后的分块写入到新文件中。
        :param origin_16: 原图像路径
        :param chunk_size: 分块大小
        :return: None
        '''
        def linear_stretching(img, min, max):
            '''
            @todo 线性拉伸
            @param img: 图像 二维数组
            @return: 拉伸后的图像
            '''
            img = np.array(img)
            img = (img - min) / (max - min) * 255
            return img

        array_data, rows, cols, bands = self.read_img(origin_16)

        global_min = np.min(array_data)
        global_max = np.max(array_data)

        read_pre_dataset = gdal.Open(origin_16)
        img_transf = read_pre_dataset.GetGeoTransform()  # 仿射矩阵
        img_proj = read_pre_dataset.GetProjection()  # 地图投影信息
        filename = origin_16[:-4] + '_unit8' + ".tif"
        driver = gdal.GetDriverByName("GTiff")  # 创建文件驱动
        dataset = driver.Create(filename, cols, rows, bands, gdal.GDT_Byte)
        dataset.SetGeoTransform(img_transf)  # 写入仿射变换参数
        dataset.SetProjection(img_proj)  # 写入投影

        total_chunks = (rows // chunk_size + 1) * (cols // chunk_size + 1) * bands
        with tqdm(total=total_chunks) as pbar:
            for i in range(bands):
                for start_row in range(0, rows, chunk_size):
                    end_row = start_row + chunk_size if start_row + chunk_size < rows else rows
                    for start_col in range(0, cols, chunk_size):
                        end_col = start_col + chunk_size if start_col + chunk_size < cols else cols
                        chunk = array_data[i, start_row:end_row, start_col:end_col]
                        stretched_chunk = linear_stretching(chunk, global_min, global_max)
                        dataset.GetRasterBand(i + 1).WriteArray(stretched_chunk, xoff=start_col, yoff=start_row)
                        pbar.update()
        dataset = None

    def listdir(self, path, list_name):
        '''
        用于获取指定路径下所有符合指定格式的文件的列表
        :param path: 文件路径
        :param list_name: 用于存储文件名的列表
        :return:  返回文件名列表
        '''
        panduan_geshi = [".tif", ".pix", ".img", ]
        for file in os.listdir(path):
            file_path = os.path.join(path, file)
            if os.path.isdir(file_path):
                self.listdir(file_path, list_name)
            elif os.path.isfile(file_path):
                if os.path.splitext(file_path)[1] in panduan_geshi:
                    list_name.append(file_path)
            else:
                print("File Error")

    def isfile(self, path):
        '''
        判断是否是文件
        '''
        if os.path.isfile(path):
            return True
        else:
            return False

    def ispath(self, path):
        '''
        判断是否是路径
        '''
        if os.path.isdir(path):
            return True
        else:
            return False

    def process(self):
        '''
        整个类的入口点。它首先检查给定的路径是文件还是目录,然后对每个文件调用 compress 方法进行处理。
        '''
        if self.ispath(self.path):
            lisst = []
            self.listdir(self.path, lisst)
            print(len(lisst))
            for file_name in lisst:
                print(file_name)
                start = time.time()
                self.compress(file_name)
                print("cost time:", time.time() - start)
        if self.isfile(self.path):
            start = time.time()
            self.compress(self.path)
            print("cost time:", time.time() - start)


if __name__ == '__main__':
    path = r"D:\\GF2.tif"
    compressor = ImageCompressor(path)
    compressor.process()

image-20240227163546474

代码解析

这段代码是一个名为ImageCompressor的类,主要用于处理大型地理信息系统(GIS)图像数据,并将16位数据压缩为8位数据。它使用了gdal库来读取和写入图像,numpy库来处理图像数据,以及tqdm库来显示处理进度。以下是对这段代码主要部分的解释:

  1. __init__函数:初始化函数,接收一个路径参数。
  2. read_img函数:使用gdal库打开指定路径的图像文件,读取图像的大小、波段数和数据类型,并将图像数据读取为numpy数组。
  3. write_img函数:根据输入的图像路径和数组,使用gdal库创建一个新的图像文件,并将数组数据写入新文件。如果输入数组的数据类型是8位无符号整数,新文件的数据类型将设置为GDT_Byte;如果是16位整数,将设置为GDT_UInt16;否则,将设置为GDT_Float32
  4. compress函数:这是主要的图像处理函数。它首先读取原始16位图像数据,然后对图像进行线性拉伸,将数据范围从全局最小值和最大值拉伸到0和255。然后,它将处理后的8位数据写入新的图像文件。这个函数使用了块处理策略,将图像分割成多个小块进行处理,以降低内存使用。
  5. listdir函数:递归地列出指定路径下所有的.tif, .pix, .img文件。
  6. isfileispath函数:检查指定路径是文件还是目录。
  7. process函数:如果输入路径是目录,它将处理目录下的所有图像文件;如果输入路径是文件,它将只处理这个文件。在处理每个文件之前和之后,它都会打印当前时间,以便计算处理时间。

这个类需要在实例化后调用process方法来开始处理图像,例如:

compressor = ImageCompressor("/path/to/images")
compressor.process()

这将处理指定路径下的所有.tif, .pix, .img文件,并将处理后的图像保存为新的.tif文件。

改进

把int16—>int8拓展为任意类型—>int8

把拉伸类型从线性拉伸改为 百分比线性拉伸

具体代码如下:

#!/usr/bin/env Python
# coding=utf-8
# @Time : 2024/2/26 08:40
# @File : 16to8.py 

from osgeo import gdal
import numpy as np
import time
import os
from tqdm import tqdm

class ImageCompressor:
    def __init__(self, path):
        self.path = path

    def read_img(self, input_file):
        '''
        这个方法用于打开输入的图像文件,并返回一个包含图像数据和图像的行数、列数、波段数的元组。
        :param input_file: 输入的图像文件
        :return: 包含图像数据和图像的行数、列数、波段数的元组
        '''
        in_ds = gdal.Open(input_file, gdal.GA_ReadOnly)  # 仅打开文件
        rows = in_ds.RasterYSize  # 获取数据高度
        cols = in_ds.RasterXSize  # 获取数据宽度
        bands = in_ds.RasterCount  # 获取数据波段数
        datatype = in_ds.GetRasterBand(1).DataType
        # GDT_Byte = 1, GDT_UInt16 = 2, GDT_UInt32 = 4, GDT_Int32 = 5, GDT_Float32 = 6
        data_type_name = {1: 'GDT_Byte', 2: 'GDT_UInt16',3:'GDT_Int16', 4: 'GDT_UInt32', 5: 'GDT_Int32', 6: 'GDT_Float32', 7:'GDT_CFloat64'}
        print("数据类型:", data_type_name[datatype])
        array_data = in_ds.ReadAsArray()  # 将数据写成数组,读取全部数据,numpy数组
        del in_ds
        return array_data, rows, cols, bands

    def write_img(self, read_path, img_array):
        '''
        这个方法用于将图像数据写入到一个新的文件中。
        :param read_path: 读取的文件路径
        :param img_array: 图像数据
        :return: None
        '''
        read_pre_dataset = gdal.Open(read_path)
        img_transf = read_pre_dataset.GetGeoTransform()  # 仿射矩阵
        img_proj = read_pre_dataset.GetProjection()  # 地图投影信息
        print("read shape:", img_array.shape, img_array.dtype.name)
        if 'uint8' in img_array.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in img_array.dtype.name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        if len(img_array.shape) == 3:
            img_bands, im_height, im_width = img_array.shape
        else:
            img_bands, (im_height, im_width) = 1, img_array.shape
        filename = read_path[:-4] + '__unit8' + ".tif"
        driver = gdal.GetDriverByName("GTiff")  # 创建文件驱动
        dataset = driver.Create(filename, im_width, im_height, img_bands, datatype)
        dataset.SetGeoTransform(img_transf)  # 写入仿射变换参数
        dataset.SetProjection(img_proj)  # 写入投影
        # 无效值为0


        if img_bands == 1:
            dataset.GetRasterBand(1).WriteArray(img_array)
        else:
            for i in range(img_bands):
                dataset.GetRasterBand(i + 1).WriteArray(img_array[i])

    def compress(self, origin_16, chunk_size=5000):
        '''
        这个方法是图像压缩的主要部分。它首先读取原图像的数据,然后计算全局的最大值和最小值,然后创建一个新的文件,
        并在每个分块上应用线性拉伸,最后将处理后的分块写入到新文件中。
        :param origin_16: 原图像路径
        :param chunk_size: 分块大小
        :return: None
        '''
        def linear_stretching(img, min, max):
            '''
            @todo 线性拉伸
            @param img: 图像 二维数组
            @return: 拉伸后的图像
            '''
            img = np.array(img)
            img = (img - min) / (max - min) * 255
            return img

        array_data, rows, cols, bands = self.read_img(origin_16)

        global_min = np.min(array_data)
        global_max = np.max(array_data)

        global_min = np.percentile(array_data, 2)
        global_max = np.percentile(array_data, 96)
        if 'float64' in global_min.dtype.name:
            global_min = 0
        read_pre_dataset = gdal.Open(origin_16)
        img_transf = read_pre_dataset.GetGeoTransform()  # 仿射矩阵
        img_proj = read_pre_dataset.GetProjection()  # 地图投影信息
        filename = origin_16[:-4] + '_unit8' + ".tif"
        driver = gdal.GetDriverByName("GTiff")  # 创建文件驱动
        dataset = driver.Create(filename, cols, rows, bands, gdal.GDT_Byte)
        dataset.SetGeoTransform(img_transf)  # 写入仿射变换参数
        dataset.SetProjection(img_proj)  # 写入投影

        total_chunks = (rows // chunk_size + 1) * (cols // chunk_size + 1) * bands
        with tqdm(total=total_chunks) as pbar:
            for i in range(bands):
                for start_row in range(0, rows, chunk_size):
                    end_row = start_row + chunk_size if start_row + chunk_size < rows else rows
                    for start_col in range(0, cols, chunk_size):
                        end_col = start_col + chunk_size if start_col + chunk_size < cols else cols
                        chunk = array_data[i, start_row:end_row, start_col:end_col]
                        stretched_chunk = linear_stretching(chunk, global_min, global_max)
                        stretched_chunk[stretched_chunk == 0 ] = np.nan
                        stretched_chunk_uint8 = stretched_chunk.astype(np.uint8)
                        dataset.GetRasterBand(i + 1).WriteArray(stretched_chunk_uint8, xoff=start_col, yoff=start_row)
                        pbar.update()
        dataset = None

    def listdir(self, path, list_name):
        '''
        用于获取指定路径下所有符合指定格式的文件的列表
        :param path: 文件路径
        :param list_name: 用于存储文件名的列表
        :return:  返回文件名列表
        '''
        panduan_geshi = [".tif", ".pix", ".img", ]
        for file in os.listdir(path):
            file_path = os.path.join(path, file)
            if os.path.isdir(file_path):
                self.listdir(file_path, list_name)
            elif os.path.isfile(file_path):
                if os.path.splitext(file_path)[1] in panduan_geshi:
                    list_name.append(file_path)
            else:
                print("File Error")

    def isfile(self, path):
        '''
        判断是否是文件
        '''
        if os.path.isfile(path):
            return True
        else:
            return False

    def ispath(self, path):
        '''
        判断是否是路径
        '''
        if os.path.isdir(path):
            return True
        else:
            return False

    def process(self):
        '''
        整个类的入口点。它首先检查给定的路径是文件还是目录,然后对每个文件调用 compress 方法进行处理。
        '''
        if self.ispath(self.path):
            lisst = []
            self.listdir(self.path, lisst)
            print(len(lisst))
            for file_name in lisst:
                print(file_name)
                start = time.time()
                self.compress(file_name)
                print("cost time:", time.time() - start)
        if self.isfile(self.path):
            start = time.time()
            self.compress(self.path)
            print("cost time:", time.time() - start)


if __name__ == '__main__':
    path = r"clip3.tif"
    compressor = ImageCompressor(path)
    compressor.process()