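"""Mat processing, part 2.

Purpose: read a hyperspectral .mat file, split it into neighbourhood windows,
and build the data set and the label set.

PS: this is about *splitting* the image ......
"""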


import numpy as np
import scipy.io as scio
import matplotlib.pyplot as plt
import math
from sklearn.cross_validation import train_test_split
import operator

def chage_to_matrix_index(num, height, weight):
    """Convert a 1-based, column-major linear index into a 0-based (row, col).

    Used to place each ground-truth entry on the 2-D label image. The linear
    index runs down the rows of a column first (the convention MATLAB uses),
    so index ``height`` is the last row of column 0 and ``height + 1`` is the
    first row of column 1.

    Args:
        num: 1-based linear index of the pixel.
        height: Number of rows of the image.
        weight: Number of columns of the image (its width). Not needed for
            the conversion; kept so existing callers keep working.

    Returns:
        A ``(row, col)`` tuple of 0-based Python ints.
    """
    # Shifting to a 0-based index first lets one integer divmod cover the
    # "last row of a column" case without a special branch, and avoids the
    # float division / fmod round trip. int() also protects against
    # wrap-around when ``num`` is a small unsigned NumPy scalar.
    col, row = divmod(int(num) - 1, int(height))
    return row, col

def matrix_make_label(dataset, Y_cell, X_cell):
    """Pick the class of a cropped tile by majority vote over its label patch.

    The value 0 is treated as "unlabelled": it never wins the vote as long as
    at least one non-zero label is present in the window.

    Args:
        dataset: 2-D label patch, indexable as ``dataset[i][j]``.
        Y_cell: Number of rows of the patch to scan.
        X_cell: Number of columns of the patch to scan.

    Returns:
        The most frequent non-zero label in the scanned window (on a tie, the
        label that appears first in row-major scan order), or ``999`` when the
        window holds no non-zero label at all (including an empty window),
        meaning the tile should not take part in classification.
    """
    counts = {}
    for i in range(Y_cell):
        for j in range(X_cell):
            value = dataset[i][j]
            if value == 0:
                # Background never decides the label, so it is not counted.
                continue
            counts[value] = counts.get(value, 0) + 1

    if not counts:
        return 999  # sentinel: nothing but background in this window

    # max() keeps the first key among equal counts, and dicts iterate in
    # insertion order, so ties resolve to the label seen first.
    return max(counts, key=counts.get)

def get_files(mat_path="Salinas.mat", Y_cell=29, X_cell=29):
    """Cut a hyperspectral cube into tiles and label each tile.

    Loads the ``img`` cube and the ``GroundT`` table from a .mat file, paints
    the ground truth onto a 2-D label image, then walks the image in
    non-overlapping ``Y_cell`` x ``X_cell`` tiles (rows first, then columns).
    Tiles whose label window contains only zeros are dropped. Pixels beyond
    the last full tile in either direction are ignored.

    Args:
        mat_path: Path of the .mat file to read.
        Y_cell: Tile height in pixels.
        X_cell: Tile width in pixels.

    Returns:
        ``(train_list, re_label_list)``: the image tiles, each of shape
        ``(Y_cell, X_cell, channels)``, and the matching tile labels, in the
        same order.
    """
    data_mat = scio.loadmat(mat_path)
    img = np.array(data_mat['img'])
    # After the transpose each row is (linear index, class label); the
    # original author notes a shape of (54129, 2) for the Salinas file.
    GroundT = np.array(data_mat['GroundT']).T
    height, weight, channels = img.shape
    print(height, weight, channels)  # original author notes: 512 217 204

    # ---- Build the 2-D label image; pixels without ground truth stay 0 ----
    img_GroundT = np.zeros([height, weight])
    for t in range(GroundT.shape[0]):
        i, j = chage_to_matrix_index(GroundT[t][0], height, weight)
        img_GroundT[i][j] = GroundT[t][1]

    # ---- Cut labels and data together so both lists stay aligned ----
    train_list = []
    re_label_list = []
    for i in range(height // Y_cell):
        row_start = i * Y_cell
        # Slice stops are exclusive in Python, so the stop is start + size.
        # Subtracting 1 (as a port from inclusive indexing would) silently
        # drops the last row/column of every data tile while the label is
        # still voted on the full window.
        row_stop = row_start + Y_cell
        for j in range(weight // X_cell):
            col_start = j * X_cell
            col_stop = col_start + X_cell

            label_patch = img_GroundT[row_start:row_stop, col_start:col_stop]
            label = matrix_make_label(label_patch, Y_cell, X_cell)
            if label == 999:
                # Window holds no labelled pixel: skip both tile and label.
                continue

            train_list.append(img[row_start:row_stop, col_start:col_stop, :])
            re_label_list.append(label)

    return train_list, re_label_list

# Script entry point: build the tiles and labels, then report their shapes.
if __name__ == '__main__':
    train_list, re_label_list = get_files()
    print(np.array(train_list).shape)
    print(np.array(re_label_list).shape)