# PyTorch practice (pytorch 연습)
# Source: GitHub Gist znxkznxk1030/635e327f454451317f91b21f33c9bc3d
# Created: March 19, 2025 07:41
#
# GitHub notice: this file contains hidden or bidirectional Unicode text that
# may be interpreted or compiled differently than what appears below. To
# review it, open the file in an editor that reveals hidden Unicode characters.
import torch
import torch.nn as nn

# Dummy input layout: (batch, channels, height, width).
batch, channels, height, width = 1, 4, 4, 4
input_size = (batch, channels, height, width)

# Grouped 3x3 convolution. With groups=2 the 4 input channels are split into
# two groups of 2, and each group is convolved independently to produce 1 of
# the 2 output channels. stride=1 with padding=1 preserves the 4x4 size.
conv = nn.Conv2d(
    in_channels=channels,
    out_channels=2,
    kernel_size=(3, 3),
    stride=1,
    padding=1,
    groups=2,
)

# Sample the input from a standard normal distribution.
input_data = torch.randn(input_size)

# Forward pass through the grouped convolution.
output = conv(input_data)

# Show shapes and contents on both sides of the layer.
print("Input size:", input_data.size())
print(input_data)
print("Output size:", output.size())
print(output)
import torch.nn as nn


class Model(nn.Module):
    """Two-layer MLP: Linear(100->50) -> BatchNorm1d -> ReLU -> Linear(50->10).

    Input is expected as (batch, 100) and the output is (batch, 10). In
    training mode BatchNorm1d needs more than one sample per batch.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes parameter / state_dict ordering.
        self.fc1 = nn.Linear(100, 50)
        self.bn = nn.BatchNorm1d(num_features=50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of 100-feature rows to 10-dimensional outputs."""
        hidden = self.relu(self.bn(self.fc1(x)))
        return self.fc2(hidden)
import torch
import torch.nn as nn

# Fully connected layer mapping a 4-feature input to a 10-feature output:
# y = x @ W.T + b, with W of shape (10, 4) and b of shape (10,).
linear = nn.Linear(4, 10)

# All-zero input vector of length 4.
input_tensor = torch.zeros(4)
print("Input Tensor Size: ", input_tensor.size())
print(input_tensor)

# Apply the linear transformation. Because the input is all zeros, the
# x @ W.T term vanishes and the output equals the bias vector.
output_tensor = linear(input_tensor)

print("Linear Weight", linear.weight)
print("Linear Bias", linear.bias)
# Reduce the bias with Tensor.sum() rather than the builtin sum(), which
# iterates the tensor element by element in Python.
print(linear.bias.sum())
print("Output Tensor Size: ", output_tensor.size())
print(output_tensor)
# Random 4-D tensor to flatten.
x = torch.randn(1, 2, 3, 4)

# nn.Flatten keeps dim 0 and merges dims 1..-1 (its defaults, spelled out),
# so (1, 2, 3, 4) becomes (1, 24).
flatten = nn.Flatten(start_dim=1, end_dim=-1)
x_flatten = flatten(x)

# Print the original tensor, then the flattened one, each on its own line.
print(x, x_flatten, sep="\n")
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import IPython.display as ipd

# List every MNIST-looking file under the working directory. Without
# recursive=True, "**" behaves like "*" and matches only one directory level.
mnist_files = glob.glob('./**/*mnist*.*', recursive=True)
print(mnist_files)

# NOTE(review): the Colab sample_data MNIST CSVs have no header row (each row
# is <label>, <784 pixels>). With pandas' default header=0, the first sample
# would be consumed as column names and silently dropped, so header=None
# is passed. Verify this if the files come from another source.
df = pd.read_csv('./sample_data/mnist_train_small.csv', header=None)
df_test = pd.read_csv('./sample_data/mnist_test.csv', header=None)
print(df.shape, df_test.shape)
def imshowTen(imgs, labels, n):
    """Draw the first ``n`` images in a single row, each titled with its label.

    Args:
        imgs: indexable collection of images accepted by ``Axes.imshow``
            (presumably 28x28 grayscale arrays, possibly with a trailing
            channel of size 1 — confirm for new callers).
        labels: indexable collection; ``labels[i]`` becomes image i's title.
        n: number of images to draw, one subplot each.

    Returns:
        The matplotlib Figure containing the row of images.
    """
    # squeeze=False keeps `axs` a 2-D array even when n == 1. With the default
    # squeeze=True a single subplot comes back as a bare Axes, which has no
    # `.flat` attribute.
    fig, axs = plt.subplots(1, n, figsize=(n, 2), squeeze=False)
    for i, ax in enumerate(axs.flat):
        ax.imshow(imgs[i], cmap='gray')
        ax.set_title(f"{labels[i]}")
        # Hide tick marks; pixel coordinates add nothing to a thumbnail.
        ax.set_xticks([])
        ax.set_yticks([])
    return fig
# Number of training samples to preview, clamped so a short frame cannot
# cause an out-of-range index inside imshowTen.
n = min(17, len(df))

# Column 0 holds the label; the remaining columns are the flattened pixels.
data = df.iloc[:n, 1:].to_numpy()
labels = df.iloc[:n, 0].to_numpy()
print(data.shape)

# Each flat pixel row becomes a 28x28 single-channel image.
imgs = data.reshape(-1, 28, 28, 1)
imshowTen(imgs, labels, n)

# Render the figure when run as a plain script (notebooks show it inline).
plt.show()
# (GitHub page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.