1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
""" 解析MNIST数据集的IDX格式文件 """ import scipy.misc import numpy as np import struct import matplotlib.pyplot as plt import os
# Directory that holds the raw MNIST IDX files; all decoded output is written
# below this directory as well.
dataset_path = "/home/ryancrj/data/mnist-dataset/"

# Training split: raw IDX inputs (file names relative to dataset_path).
train_image_idx_ubyte_file = 'train-images.idx3-ubyte'
train_labels_idx_ubyte_file = 'train-labels.idx1-ubyte'

# Training split: output image directory and label text file
# (relative to dataset_path).
save_train_images_path = "train_images"
save_train_labels_file = "train_labels.txt"

# Test split: raw IDX inputs (file names relative to dataset_path).
test_image_idx_ubyte_file = 't10k-images.idx3-ubyte'
test_labels_idx_ubyte_file = 't10k-labels.idx1-ubyte'

# Test split: output image directory and label text file
# (relative to dataset_path).
save_test_images_path = "test_images"
save_test_labels_file = "test_labels.txt"
def decode_idx3_ubyte(idx3_ubyte_file, save_path):
    """Decode an IDX3 image file and dump every image to disk as a JPEG.

    The file is read as four big-endian 32-bit integers (magic number, image
    count, rows, columns) followed by ``rows * cols`` unsigned bytes per
    image.

    :param idx3_ubyte_file: path of the IDX3 file to read
    :param save_path: directory that receives ``1.jpg``, ``2.jpg``, ...
        (1-based index); the callers in this file create it beforehand
    :return: array of shape ``(num_images, num_rows, num_cols)`` holding the
        pixel values (allocated with NumPy's default dtype)
    """
    # Read the whole file up front; the context manager closes the handle.
    with open(idx3_ubyte_file, 'rb') as fin:
        bin_data = fin.read()

    # Parse the header.
    offset = 0
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(
        fmt_header, bin_data, offset)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('魔术数: {},图片数量: {},图片大小: {} * {}'.format(
        magic_number, num_images, num_rows, num_cols))

    # Parse the pixel data, one image at a time.
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)
    fmt_image = '>' + str(image_size) + 'B'
    image_bytes = struct.calcsize(fmt_image)  # loop-invariant, computed once
    images = np.empty((num_images, num_rows, num_cols))
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print("已经解析 %d" % (i + 1) + " 张")
        images[i] = np.array(
            struct.unpack_from(fmt_image, bin_data, offset)
        ).reshape((num_rows, num_cols))
        offset += image_bytes
        # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and
        # removed in SciPy 1.2, so this call needs an older SciPy. Kept as-is
        # because a replacement writer would add a dependency and may change
        # the saved files -- confirm the target environment.
        scipy.misc.imsave(
            os.path.join(save_path, '{}.jpg'.format(i + 1)), images[i])
    return images
def decode_idx1_ubyte(idx1_ubyte_file, save_file):
    """Decode an IDX1 label file and write the labels to a text file.

    The file is read as two big-endian 32-bit integers (magic number, label
    count) followed by one unsigned byte per label.

    :param idx1_ubyte_file: path of the IDX1 file to read
    :param save_file: path of the text file to (over)write; it receives one
        integer label per line
    :return: 1-D array of length ``num_labels`` holding the label values
        (allocated with NumPy's default dtype)
    """
    # Read the whole file up front; the context manager closes the handle.
    with open(idx1_ubyte_file, 'rb') as fin:
        bin_data = fin.read()

    # Parse the header.
    offset = 0
    fmt_header = '>ii'
    magic_number, num_labels = struct.unpack_from(
        fmt_header, bin_data, offset)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('魔术数: {},标签数量: {}'.format(magic_number, num_labels))

    # Parse the labels.
    offset += struct.calcsize(fmt_header)
    fmt_label = '>B'
    label_bytes = struct.calcsize(fmt_label)  # loop-invariant, computed once
    labels = np.empty(num_labels)
    # The with-block guarantees the label file is flushed and closed before
    # returning, even if decoding fails part-way through.
    with open(save_file, 'w') as fout:
        for i in range(num_labels):
            if (i + 1) % 10000 == 0:
                print("已经解析 %d" % (i + 1) + " 个")
            labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
            offset += label_bytes
            fout.write(str(int(labels[i])) + '\n')
    return labels
def load_train_images():
    """Decode the training images and export them under the dataset directory.

    Creates the ``save_train_images_path`` directory inside ``dataset_path``
    when it does not exist yet, then hands the training IDX3 file to
    ``decode_idx3_ubyte``.

    :return: whatever ``decode_idx3_ubyte`` returns for the training images
    """
    output_dir = os.path.join(dataset_path, save_train_images_path)
    source_file = os.path.join(dataset_path, train_image_idx_ubyte_file)

    directory_missing = not os.path.exists(output_dir)
    if directory_missing:
        os.mkdir(output_dir)

    return decode_idx3_ubyte(source_file, output_dir)
def load_train_labels():
    """Decode the training labels and export them as a text file.

    Both the IDX1 input and the text output live inside ``dataset_path``.

    :return: whatever ``decode_idx1_ubyte`` returns for the training labels
    """
    source_file = os.path.join(dataset_path, train_labels_idx_ubyte_file)
    output_file = os.path.join(dataset_path, save_train_labels_file)
    return decode_idx1_ubyte(source_file, output_file)
def load_test_images():
    """Decode the test images and export them under the dataset directory.

    Creates the ``save_test_images_path`` directory inside ``dataset_path``
    when it does not exist yet, then hands the test IDX3 file to
    ``decode_idx3_ubyte``.

    :return: whatever ``decode_idx3_ubyte`` returns for the test images
    """
    output_dir = os.path.join(dataset_path, save_test_images_path)
    source_file = os.path.join(dataset_path, test_image_idx_ubyte_file)

    directory_missing = not os.path.exists(output_dir)
    if directory_missing:
        os.mkdir(output_dir)

    return decode_idx3_ubyte(source_file, output_dir)
def load_test_labels():
    """Decode the test labels and export them as a text file.

    Both the IDX1 input and the text output live inside ``dataset_path``.

    :return: whatever ``decode_idx1_ubyte`` returns for the test labels
    """
    source_file = os.path.join(dataset_path, test_labels_idx_ubyte_file)
    output_file = os.path.join(dataset_path, save_test_labels_file)
    return decode_idx1_ubyte(source_file, output_file)
def test():
    """Decode the test split and display its first ten samples.

    Prints each of the first ten labels and shows the matching image in a
    matplotlib window, then prints ``done``. Decoding also performs the
    export side effects of ``load_test_images`` / ``load_test_labels``.
    """
    test_images = load_test_images()
    test_labels = load_test_labels()

    # Eyeball the first ten samples to confirm that images and labels line up.
    for i in range(10):
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(test_labels[i])
        plt.imshow(test_images[i], cmap='gray')
        plt.show()
    print('done')
def parse_data():
    """Decode and export the complete dataset: both splits, images and labels.

    The loaders run in the order train images, train labels, test images,
    test labels. Their return values are deliberately not bound to names:
    nothing here uses them, and keeping them would hold every decoded array
    in memory until the function returns.
    """
    load_train_images()
    load_train_labels()
    load_test_images()
    load_test_labels()
# Script entry point: decode and export the whole dataset when this file is
# executed directly (not when it is imported).
if __name__ == '__main__': parse_data()
|