You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
267 lines
12 KiB
Plaintext
267 lines
12 KiB
Plaintext
4 years ago
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"ename": "ModuleNotFoundError",
|
||
|
"evalue": "No module named 'numpy'",
|
||
|
"output_type": "error",
|
||
|
"traceback": [
|
||
|
"\u001b[0;31m--------------------------------------------------------\u001b[0m",
|
||
|
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||
|
"\u001b[0;32m<ipython-input-2-d9bbc8b73862>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcdefaults\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlines\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLine2D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'numpy'"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"import os\n",
|
||
|
"import numpy as np\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"plt.rcdefaults()\n",
|
||
|
"from matplotlib.lines import Line2D\n",
|
||
|
"from matplotlib.patches import Rectangle\n",
|
||
|
"from matplotlib.patches import Circle\n",
|
||
|
"\n",
|
||
|
"NumDots = 4\n",
|
||
|
"NumConvMax = 8\n",
|
||
|
"NumFcMax = 20\n",
|
||
|
"White = 1.\n",
|
||
|
"Light = 0.7\n",
|
||
|
"Medium = 0.5\n",
|
||
|
"Dark = 0.3\n",
|
||
|
"Darker = 0.15\n",
|
||
|
"Black = 0.\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def add_layer(patches, colors, size=(24, 24), num=5,\n",
|
||
|
" top_left=[0, 0],\n",
|
||
|
" loc_diff=[3, -3],\n",
|
||
|
" ):\n",
|
||
|
" # add a rectangle\n",
|
||
|
" top_left = np.array(top_left)\n",
|
||
|
" loc_diff = np.array(loc_diff)\n",
|
||
|
" loc_start = top_left - np.array([0, size[0]])\n",
|
||
|
" for ind in range(num):\n",
|
||
|
" patches.append(Rectangle(loc_start + ind * loc_diff, size[1], size[0]))\n",
|
||
|
" if ind % 2:\n",
|
||
|
" colors.append(Medium)\n",
|
||
|
" else:\n",
|
||
|
" colors.append(Light)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def add_layer_with_omission(patches, colors, size=(24, 24),\n",
|
||
|
" num=5, num_max=8,\n",
|
||
|
" num_dots=4,\n",
|
||
|
" top_left=[0, 0],\n",
|
||
|
" loc_diff=[3, -3],\n",
|
||
|
" ):\n",
|
||
|
" # add a rectangle\n",
|
||
|
" top_left = np.array(top_left)\n",
|
||
|
" loc_diff = np.array(loc_diff)\n",
|
||
|
" loc_start = top_left - np.array([0, size[0]])\n",
|
||
|
" this_num = min(num, num_max)\n",
|
||
|
" start_omit = (this_num - num_dots) // 2\n",
|
||
|
" end_omit = this_num - start_omit\n",
|
||
|
" start_omit -= 1\n",
|
||
|
" for ind in range(this_num):\n",
|
||
|
" if (num > num_max) and (start_omit < ind < end_omit):\n",
|
||
|
" omit = True\n",
|
||
|
" else:\n",
|
||
|
" omit = False\n",
|
||
|
"\n",
|
||
|
" if omit:\n",
|
||
|
" patches.append(\n",
|
||
|
" Circle(loc_start + ind * loc_diff + np.array(size) / 2, 0.5))\n",
|
||
|
" else:\n",
|
||
|
" patches.append(Rectangle(loc_start + ind * loc_diff,\n",
|
||
|
" size[1], size[0]))\n",
|
||
|
"\n",
|
||
|
" if omit:\n",
|
||
|
" colors.append(Black)\n",
|
||
|
" elif ind % 2:\n",
|
||
|
" colors.append(Medium)\n",
|
||
|
" else:\n",
|
||
|
" colors.append(Light)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def add_mapping(patches, colors, start_ratio, end_ratio, patch_size, ind_bgn,\n",
|
||
|
" top_left_list, loc_diff_list, num_show_list, size_list):\n",
|
||
|
"\n",
|
||
|
" start_loc = top_left_list[ind_bgn] \\\n",
|
||
|
" + (num_show_list[ind_bgn] - 1) * np.array(loc_diff_list[ind_bgn]) \\\n",
|
||
|
" + np.array([start_ratio[0] * (size_list[ind_bgn][1] - patch_size[1]),\n",
|
||
|
" - start_ratio[1] * (size_list[ind_bgn][0] - patch_size[0])]\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
" end_loc = top_left_list[ind_bgn + 1] \\\n",
|
||
|
" + (num_show_list[ind_bgn + 1] - 1) * np.array(\n",
|
||
|
" loc_diff_list[ind_bgn + 1]) \\\n",
|
||
|
" + np.array([end_ratio[0] * size_list[ind_bgn + 1][1],\n",
|
||
|
" - end_ratio[1] * size_list[ind_bgn + 1][0]])\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
" patches.append(Rectangle(start_loc, patch_size[1], -patch_size[0]))\n",
|
||
|
" colors.append(Dark)\n",
|
||
|
" patches.append(Line2D([start_loc[0], end_loc[0]],\n",
|
||
|
" [start_loc[1], end_loc[1]]))\n",
|
||
|
" colors.append(Darker)\n",
|
||
|
" patches.append(Line2D([start_loc[0] + patch_size[1], end_loc[0]],\n",
|
||
|
" [start_loc[1], end_loc[1]]))\n",
|
||
|
" colors.append(Darker)\n",
|
||
|
" patches.append(Line2D([start_loc[0], end_loc[0]],\n",
|
||
|
" [start_loc[1] - patch_size[0], end_loc[1]]))\n",
|
||
|
" colors.append(Darker)\n",
|
||
|
" patches.append(Line2D([start_loc[0] + patch_size[1], end_loc[0]],\n",
|
||
|
" [start_loc[1] - patch_size[0], end_loc[1]]))\n",
|
||
|
" colors.append(Darker)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def label(xy, text, xy_off=[0, 4]):\n",
|
||
|
" plt.text(xy[0] + xy_off[0], xy[1] + xy_off[1], text,\n",
|
||
|
" family='sans-serif', size=8)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"if __name__ == '__main__':\n",
|
||
|
"\n",
|
||
|
" fc_unit_size = 2\n",
|
||
|
" layer_width = 40\n",
|
||
|
" flag_omit = True\n",
|
||
|
"\n",
|
||
|
" patches = []\n",
|
||
|
" colors = []\n",
|
||
|
"\n",
|
||
|
" fig, ax = plt.subplots()\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
" ############################\n",
|
||
|
" # conv layers\n",
|
||
|
" size_list = [(32, 32), (18, 18), (10, 10), (6, 6), (4, 4)]\n",
|
||
|
" num_list = [3, 32, 32, 48, 48]\n",
|
||
|
" x_diff_list = [0, layer_width, layer_width, layer_width, layer_width]\n",
|
||
|
" text_list = ['Inputs'] + ['Feature\\nmaps'] * (len(size_list) - 1)\n",
|
||
|
" loc_diff_list = [[3, -3]] * len(size_list)\n",
|
||
|
"\n",
|
||
|
" num_show_list = list(map(min, num_list, [NumConvMax] * len(num_list)))\n",
|
||
|
" top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]\n",
|
||
|
"\n",
|
||
|
" for ind in range(len(size_list)-1,-1,-1):\n",
|
||
|
" if flag_omit:\n",
|
||
|
" add_layer_with_omission(patches, colors, size=size_list[ind],\n",
|
||
|
" num=num_list[ind],\n",
|
||
|
" num_max=NumConvMax,\n",
|
||
|
" num_dots=NumDots,\n",
|
||
|
" top_left=top_left_list[ind],\n",
|
||
|
" loc_diff=loc_diff_list[ind])\n",
|
||
|
" else:\n",
|
||
|
" add_layer(patches, colors, size=size_list[ind],\n",
|
||
|
" num=num_show_list[ind],\n",
|
||
|
" top_left=top_left_list[ind], loc_diff=loc_diff_list[ind])\n",
|
||
|
" label(top_left_list[ind], text_list[ind] + '\\n{}@{}x{}'.format(\n",
|
||
|
" num_list[ind], size_list[ind][0], size_list[ind][1]))\n",
|
||
|
"\n",
|
||
|
" ############################\n",
|
||
|
" # in between layers\n",
|
||
|
" start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]\n",
|
||
|
" end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]\n",
|
||
|
" patch_size_list = [(5, 5), (2, 2), (5, 5), (2, 2)]\n",
|
||
|
" ind_bgn_list = range(len(patch_size_list))\n",
|
||
|
" text_list = ['Convolution', 'Max-pooling', 'Convolution', 'Max-pooling']\n",
|
||
|
"\n",
|
||
|
" for ind in range(len(patch_size_list)):\n",
|
||
|
" add_mapping(\n",
|
||
|
" patches, colors, start_ratio_list[ind], end_ratio_list[ind],\n",
|
||
|
" patch_size_list[ind], ind,\n",
|
||
|
" top_left_list, loc_diff_list, num_show_list, size_list)\n",
|
||
|
" label(top_left_list[ind], text_list[ind] + '\\n{}x{} kernel'.format(\n",
|
||
|
" patch_size_list[ind][0], patch_size_list[ind][1]), xy_off=[26, -65]\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
" ############################\n",
|
||
|
" # fully connected layers\n",
|
||
|
" size_list = [(fc_unit_size, fc_unit_size)] * 3\n",
|
||
|
" num_list = [768, 500, 2]\n",
|
||
|
" num_show_list = list(map(min, num_list, [NumFcMax] * len(num_list)))\n",
|
||
|
" x_diff_list = [sum(x_diff_list) + layer_width, layer_width, layer_width]\n",
|
||
|
" top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]\n",
|
||
|
" loc_diff_list = [[fc_unit_size, -fc_unit_size]] * len(top_left_list)\n",
|
||
|
" text_list = ['Hidden\\nunits'] * (len(size_list) - 1) + ['Outputs']\n",
|
||
|
"\n",
|
||
|
" for ind in range(len(size_list)):\n",
|
||
|
" if flag_omit:\n",
|
||
|
" add_layer_with_omission(patches, colors, size=size_list[ind],\n",
|
||
|
" num=num_list[ind],\n",
|
||
|
" num_max=NumFcMax,\n",
|
||
|
" num_dots=NumDots,\n",
|
||
|
" top_left=top_left_list[ind],\n",
|
||
|
" loc_diff=loc_diff_list[ind])\n",
|
||
|
" else:\n",
|
||
|
" add_layer(patches, colors, size=size_list[ind],\n",
|
||
|
" num=num_show_list[ind],\n",
|
||
|
" top_left=top_left_list[ind],\n",
|
||
|
" loc_diff=loc_diff_list[ind])\n",
|
||
|
" label(top_left_list[ind], text_list[ind] + '\\n{}'.format(\n",
|
||
|
" num_list[ind]))\n",
|
||
|
"\n",
|
||
|
" text_list = ['Flatten\\n', 'Fully\\nconnected', 'Fully\\nconnected']\n",
|
||
|
"\n",
|
||
|
" for ind in range(len(size_list)):\n",
|
||
|
" label(top_left_list[ind], text_list[ind], xy_off=[-10, -65])\n",
|
||
|
"\n",
|
||
|
" ############################\n",
|
||
|
" for patch, color in zip(patches, colors):\n",
|
||
|
" patch.set_color(color * np.ones(3))\n",
|
||
|
" if isinstance(patch, Line2D):\n",
|
||
|
" ax.add_line(patch)\n",
|
||
|
" else:\n",
|
||
|
" patch.set_edgecolor(Black * np.ones(3))\n",
|
||
|
" ax.add_patch(patch)\n",
|
||
|
"\n",
|
||
|
" plt.tight_layout()\n",
|
||
|
" plt.axis('equal')\n",
|
||
|
" plt.axis('off')\n",
|
||
|
" plt.show()\n",
|
||
|
" fig.set_size_inches(8, 2.5)\n",
|
||
|
"\n",
|
||
|
" # fig_dir = './'\n",
|
||
|
" # fig_ext = '.png'\n",
|
||
|
" # fig.savefig(os.path.join(fig_dir, 'convnet_fig' + fig_ext),\n",
|
||
|
" # bbox_inches='tight', pad_inches=0)\n",
|
||
|
"\n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|