Created
October 19, 2020 02:20
-
-
Save tuduweb/489570b1958b091b843012c1ad0ffcec to your computer and use it in GitHub Desktop.
20201019
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# .gitignore for the .idea directory (JetBrains project metadata).
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/../../../../:\project\ML-CL-DDoS\.idea/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!-- IDE inspection-profile settings: the project-level profile is disabled,
     so the IDE falls back to the application-level profile. -->
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!-- Project SDK configuration: a Python 3.8 (conda "base") interpreter. -->
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (base)" project-jdk-type="Python SDK" />
</project>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!-- PyCharm module file: single content root, inherited project SDK,
     Google-style docstrings, pytest as the test runner. -->
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="GOOGLE" />
    <option name="myDocStringFormat" value="Google" />
  </component>
  <component name="TestRunnerService">
    <option name="PROJECT_TEST_RUNNER" value="pytest" />
  </component>
</module>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!-- Registers the single project module (ML-CL-DDoS.iml) with the IDE. -->
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/ML-CL-DDoS.iml" filepath="$PROJECT_DIR$/.idea/ML-CL-DDoS.iml" />
    </modules>
  </component>
</project>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!-- PyCharm workspace.xml: per-user IDE state (changelists, run
     configuration, task history, window geometry). This file is listed in
     the project's .gitignore and is normally not committed; it appears here
     only as part of the gist snapshot. -->
<project version="4">
  <component name="ChangeListManager">
    <list default="true" id="cc085207-44ad-4a53-9562-ba3ad78d023c" name="Default Changelist" comment="" />
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="JupyterTrust" id="552d16af-463a-43d3-bb47-b09f93920a31" />
  <component name="ProjectId" id="1j4qWBVhGIXNrt9A5yP4MoMXGgD" />
  <component name="ProjectViewState">
    <option name="hideEmptyMiddlePackages" value="true" />
    <option name="showLibraryContents" value="true" />
  </component>
  <component name="PropertiesComponent">
    <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
    <property name="WebServerToolWindowFactoryState" value="false" />
    <property name="last_opened_file_path" value="$PROJECT_DIR$/project/model.py" />
  </component>
  <component name="RunManager">
    <!-- Run configuration for project/model.py using a local Anaconda
         interpreter (SDK_HOME is machine-specific). -->
    <configuration name="model" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
      <module name="ML-CL-DDoS" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="D:\anaconda3\python.exe" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/project" />
      <option name="IS_MODULE_SDK" value="false" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/project/model.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
  </component>
  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="cc085207-44ad-4a53-9562-ba3ad78d023c" name="Default Changelist" comment="" />
      <created>1603073740902</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1603073740902</updated>
      <workItem from="1603073741948" duration="38000" />
      <workItem from="1603073787657" duration="70000" />
    </task>
    <servers />
  </component>
  <component name="TypeScriptGeneratedFilesManager">
    <option name="version" value="3" />
  </component>
  <component name="WindowStateProjectService">
    <state x="1374" y="354" width="1092" height="716" key="#com.intellij.execution.impl.EditConfigurationsDialog" timestamp="1603073856779">
      <screen x="0" y="0" width="2560" height="1400" />
    </state>
    <state x="1374" y="354" width="1092" height="716" key="#com.intellij.execution.impl.EditConfigurationsDialog/[email protected]" timestamp="1603073856779" />
    <state x="1294" y="566" key="FileChooserDialogImpl" timestamp="1603073843440">
      <screen x="0" y="0" width="2560" height="1400" />
    </state>
    <state x="1294" y="566" key="FileChooserDialogImpl/[email protected]" timestamp="1603073843440" />
  </component>
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
    <SUITE FILE_PATH="coverage/ML_CL_DDoS$model.coverage" NAME="model Coverage Results" MODIFIED="1603073856931" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/project" />
  </component>
</project>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def aggregate_grads(grads, backend):
    """Compute the sample-weighted average of per-user gradients.

    Args:
        grads: a list of gradient-info dicts, one per user, each of the form
            {
                'n_samples': <int, samples the user trained on>,
                'named_grads': {<layer_name>: <grad array/tensor>, ...},
            }
        backend: object exposing ``sum(data, dim=0)`` (e.g. NumpyBackend or
            PytorchBackend) used to reduce the stacked gradients.

    Returns:
        dict mapping layer name to the averaged gradient:
            {'layer_name1': grads1, 'layer_name2': grads2, ...}
        Returns {} when ``grads`` is empty (the original raised
        ZeroDivisionError in that case).
    """
    if not grads:
        return {}
    total_grads = {}
    n_total_samples = 0
    for gradinfo in grads:
        n_samples = gradinfo['n_samples']
        for k, v in gradinfo['named_grads'].items():
            if k not in total_grads:
                total_grads[k] = []
            # Weight each user's gradient by its sample count so the final
            # division yields a sample-weighted mean.
            total_grads[k].append(v * n_samples)
        n_total_samples += n_samples

    gradients = {}
    for k, v in total_grads.items():
        gradients[k] = backend.sum(v, dim=0) / n_total_samples

    return gradients
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC | |
from abc import abstractmethod | |
import os | |
import numpy as np | |
import torch | |
from aggretator import aggregate_grads | |
def random_str(n):
    """Return a hex string built from `n` cryptographically random bytes.

    Leading zero bytes shorten the result, so the length is at most 2*n.
    """
    value = int.from_bytes(os.urandom(n), byteorder='big')
    return format(value, 'x')
class ModelBase(ABC):
    """Abstract base for framework-specific model wrappers."""

    def __init__(self, **kwargs):
        # Store every keyword argument as an instance attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    def update_grads(self):
        """Apply aggregated gradients to the wrapped model."""

    @abstractmethod
    def load_model(self, path):
        """Load model weights from `path`."""

    @abstractmethod
    def save_model(self, path):
        """Persist model weights to `path`."""
class PytorchModel(ModelBase):
    def __init__(self,
                 torch,
                 model_class,
                 init_model_path: str = '',
                 lr: float = 0.01,
                 optim_name: str = 'Adam',
                 cuda: bool = False):
        """PyTorch model wrapper.

        Args:
            torch: the torch module (injected to avoid a hard import here).
            model_class: zero-arg callable producing the nn.Module to train.
            init_model_path: optional path to an initial state_dict.
            lr: learning rate for the optimizer.
            optim_name: name of an optimizer class in torch.optim.
            cuda: whether to move the model to CUDA (if available).
        """
        self.torch = torch
        self.model_class = model_class
        self.init_model_path = init_model_path
        self.lr = lr
        self.optim_name = optim_name
        self.cuda = cuda
        # Bug fix: track the last loaded checkpoint path so load_model() can
        # skip redundant reloads. The original never initialized this
        # attribute, so the first load_model() call raised AttributeError.
        self.load_from_path = ''
        self._init_params()

    def _init_params(self):
        """Build the model (optionally from a checkpoint) and its optimizer."""
        self.model = self.model_class()
        if self.init_model_path:
            self.model.load_state_dict(self.torch.load(self.init_model_path))
        if self.cuda and self.torch.cuda.is_available():
            self.model = self.model.cuda()
        self.optimizer = getattr(self.torch.optim,
                                 self.optim_name)(self.model.parameters(),
                                                  lr=self.lr)

    def update_grads(self, grads):
        """Install `grads` (name -> tensor) as the parameter gradients and step."""
        self.optimizer.zero_grad()
        for name, param in self.model.named_parameters():
            # Cast to the parameter's dtype so mixed sources still apply.
            param.grad = grads[name].type(param.dtype)
        self.optimizer.step()

    def update_params(self, params):
        """Overwrite parameter values in place from `params` (name -> tensor).

        Bug fix: in-place writes to leaf tensors that require grad raise a
        RuntimeError in PyTorch, so the copy runs under no_grad().
        """
        with self.torch.no_grad():
            for name, param in self.model.named_parameters():
                param[:] = params[name]
        return self.model

    def load_model(self, path, force_reload=False):
        """Load a state_dict from `path`, skipping if it was just loaded."""
        if force_reload is False and self.load_from_path == path:
            return
        self.load_from_path = path
        # Bug fix: the original called the nonexistent `load_static_dict`,
        # which raised AttributeError on every load.
        self.model.load_state_dict(self.torch.load(path))

    def save_model(self, path):
        """Save the model's state_dict to `path`, creating parent dirs."""
        base = os.path.dirname(path)
        # Guard: dirname('file.md') == '' and makedirs('') raises.
        if base:
            os.makedirs(base, exist_ok=True)
        self.torch.save(self.model.state_dict(), path)
        return path
class BaseBackend(ABC):
    """Numeric-backend interface with a default numpy mean implementation."""

    @abstractmethod
    def mean(self, data):
        """Mean of `data` across the first axis (subclasses may delegate here)."""
        return np.asarray(data).mean(axis=0)
class NumpyBackend(BaseBackend):
    """Backend computing aggregates with numpy."""

    def mean(self, data):
        # Delegate to the shared numpy implementation on the base class.
        return super().mean(data)
class PytorchBackend(BaseBackend):
    def __init__(self, torch, cuda=False):
        """Backend computing aggregates with torch tensors.

        Args:
            torch: the torch module (injected).
            cuda: request CUDA; honored only when a device is available.
        """
        self.torch = torch
        # Bug fix: the original only assigned self.cuda inside `if cuda:`,
        # leaving the attribute undefined when cuda=False — every later
        # mean()/sum() call then raised AttributeError.
        self.cuda = bool(cuda) and torch.cuda.is_available()

    def mean(self, data, dim=0):
        """Mean of `data` along `dim`, materialized as a torch tensor."""
        return self.torch.tensor(
            data,
            device=self.torch.cuda.current_device() if self.cuda else None,
        ).mean(dim=dim)

    def sum(self, data, dim=0):
        """Sum of `data` along `dim`, materialized as a torch tensor."""
        return self.torch.tensor(
            data,
            device=self.torch.cuda.current_device() if self.cuda else None,
        ).sum(dim=dim)

    def _check_model(self, model):
        # All model-facing helpers require the PytorchModel wrapper.
        if not isinstance(model, PytorchModel):
            raise ValueError(
                "model must be type of PytorchModel not {}".format(
                    type(model)))

    def update_grads(self, model, grads):
        """Apply aggregated grads through the model wrapper."""
        self._check_model(model=model)
        return model.update_grads(grads=grads)

    def update_params(self, model, params):
        """Overwrite model parameters through the model wrapper."""
        self._check_model(model=model)
        return model.update_params(params=params)

    def load_model(self, model, path, force_reload=False):
        """Load a checkpoint through the model wrapper."""
        self._check_model(model=model)
        return model.load_model(path=path, force_reload=force_reload)

    def save_model(self, model, path):
        """Save a checkpoint through the model wrapper."""
        self._check_model(model=model)
        return model.save_model(path)
class Aggregator(object):
    """Pairs a model wrapper with the numeric backend used to aggregate."""

    def __init__(self, model, backend):
        self.model = model
        self.backend = backend
class FederatedAveragingGrads(Aggregator):
    def __init__(self, model, framework=None):
        """Federated-averaging aggregator.

        Args:
            model: model wrapper (e.g. PytorchModel).
            framework: 'numpy', 'pytorch', or None to fall back to a
                `framework` attribute on `model` (numpy if absent).
        """
        # Bug fix: getattr without a default raised AttributeError whenever
        # framework was None and the model had no `framework` attribute.
        self.framework = framework or getattr(model, 'framework', None)
        # Bug fixes below: the original branched on the raw `framework`
        # argument (ignoring the value resolved from the model) and assigned
        # the NumpyBackend CLASS instead of an instance.
        if self.framework is None or self.framework == 'numpy':
            backend = NumpyBackend()
        elif self.framework == 'pytorch':
            backend = PytorchBackend(torch=torch)
        else:
            raise ValueError(
                'Framework {} is not supported!'.format(self.framework))
        super().__init__(model, backend)

    def aggregate_grads(self, grads):
        """Aggregate model gradients and push them into the model.

        Args:
            grads: a list of grads' information, each item of the form
                {
                    'n_samples': xxx,
                    'named_grads': xxx,
                }
        """
        self.backend.update_grads(self.model,
                                  grads=aggregate_grads(grads=grads,
                                                        backend=self.backend))

    def save_model(self, path):
        """Save the wrapped model through the backend; returns the path."""
        return self.backend.save_model(self.model, path=path)

    def load_model(self, path, force_reload=False):
        """Load the wrapped model through the backend."""
        return self.backend.load_model(self.model,
                                       path=path,
                                       force_reload=force_reload)

    def __call__(self, grads):
        """Aggregate grads.

        Args:
            grads -> list: grads is a list of either the actual grad info
                or the absolute file path of grad info.
        """
        if not grads:
            return
        if not isinstance(grads, list):
            raise ValueError('grads should be a list, not {}'.format(
                type(grads)))
        actual_grads = grads
        return self.aggregate_grads(grads=actual_grads)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch.nn.functional as F | |
class FLModel(nn.Module):
    """Two-layer MLP: 79 input features -> 256 hidden -> 14 class log-probs."""

    def __init__(self):
        super().__init__()
        # Attribute names fc1/fc5 are kept: they are the state_dict keys.
        self.fc1 = nn.Linear(79, 256)
        self.fc5 = nn.Linear(256, 14)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return F.log_softmax(self.fc5(hidden), dim=1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
import os | |
import shutil | |
import unittest | |
import numpy as np | |
from sklearn.metrics import classification_report | |
import torch | |
import torch.nn.functional as F | |
from context import FederatedAveragingGrads | |
from context import PytorchModel | |
from learning_model import FLModel | |
from preprocess import get_test_loader | |
from preprocess import UserRoundData | |
from train import user_round_train | |
class ParameterServer(object):
    """Buffers per-user gradient reports and federated-averages them per round."""

    def __init__(self, init_model_path, testworkdir):
        self.round = 0
        self.rounds_info = {}
        self.rounds_model_path = {}
        self.current_round_grads = []
        self.init_model_path = init_model_path
        wrapped = PytorchModel(torch=torch,
                               model_class=FLModel,
                               init_model_path=self.init_model_path,
                               optim_name='Adam')
        self.aggr = FederatedAveragingGrads(model=wrapped,
                                            framework='pytorch')
        self.testworkdir = testworkdir
        if not os.path.exists(self.testworkdir):
            os.makedirs(self.testworkdir)

    def get_latest_model(self):
        """Return the newest saved model path, or the initial model if none."""
        if not self.rounds_model_path:
            return self.init_model_path
        current = self.rounds_model_path.get(self.round)
        if current is not None:
            return current
        return self.rounds_model_path[self.round - 1]

    def receive_grads_info(self, grads):
        # Buffer one user's gradient info for the current round.
        self.current_round_grads.append(grads)

    def aggregate(self):
        """Average buffered grads, save this round's model, advance the round."""
        self.aggr(self.current_round_grads)
        path = os.path.join(self.testworkdir,
                            'round-{round}-model.md'.format(round=self.round))
        self.rounds_model_path[self.round] = path
        # Drop the previous round's checkpoint to bound disk usage.
        previous = self.rounds_model_path.get(self.round - 1)
        if previous is not None and os.path.exists(previous):
            os.remove(previous)
        info = self.aggr.save_model(path=path)
        self.round += 1
        self.current_round_grads = []
        return info
class FedAveragingGradsTestSuit(unittest.TestCase):
    """End-to-end federated-averaging run: every user trains locally each
    round, the ParameterServer aggregates grads, and test-set predictions
    are written to RESULT_DIR."""

    RESULT_DIR = 'result'  # where result.txt predictions are written
    N_VALIDATION = 10000  # samples drawn for periodic validation
    TEST_BASE_DIR = '/tmp/'  # scratch base dir for per-round model files

    def setUp(self):
        # Hyperparameters / bookkeeping for the federated run.
        self.seed = 0
        self.use_cuda = False
        self.batch_size = 64
        self.test_batch_size = 1000
        self.lr = 0.001
        self.n_max_rounds = 10000
        self.log_interval = 10
        self.n_round_samples = 1600
        self.testbase = self.TEST_BASE_DIR
        # NOTE: 'competetion-test' is an existing on-disk name; kept as-is.
        self.testworkdir = os.path.join(self.testbase, 'competetion-test')
        if not os.path.exists(self.testworkdir):
            os.makedirs(self.testworkdir)

        self.init_model_path = os.path.join(self.testworkdir, 'init_model.md')
        torch.manual_seed(self.seed)

        # Persist a fresh initial model only once; later runs reuse it.
        if not os.path.exists(self.init_model_path):
            torch.save(FLModel().state_dict(), self.init_model_path)

        self.ps = ParameterServer(init_model_path=self.init_model_path,
                                  testworkdir=self.testworkdir)

        if not os.path.exists(self.RESULT_DIR):
            os.makedirs(self.RESULT_DIR)

        # Loads every user's CSV from disk (can be slow).
        self.urd = UserRoundData()
        self.n_users = self.urd.n_users

    def _clear(self):
        # Remove the scratch dir (init model and round checkpoints).
        shutil.rmtree(self.testworkdir)

    def tearDown(self):
        self._clear()

    def test_federated_averaging(self):
        """Run up to n_max_rounds of local-train -> aggregate, predicting
        on the test set every 200 rounds and at the end."""
        torch.manual_seed(self.seed)
        device = torch.device("cuda" if self.use_cuda else "cpu")

        training_start = datetime.now()
        model = None
        for r in range(1, self.n_max_rounds + 1):
            path = self.ps.get_latest_model()
            start = datetime.now()
            for u in range(0, self.n_users):
                # Each user starts from the latest aggregated checkpoint.
                model = FLModel()
                model.load_state_dict(torch.load(path))
                model = model.to(device)
                x, y = self.urd.round_data(
                    user_idx=u,
                    n_round=r,
                    n_round_samples=self.n_round_samples)
                grads = user_round_train(X=x, Y=y, model=model, device=device)
                self.ps.receive_grads_info(grads=grads)

            self.ps.aggregate()
            print('\nRound {} cost: {}, total training cost: {}'.format(
                r,
                datetime.now() - start,
                datetime.now() - training_start,
            ))

            # Periodic validation + test-set prediction dump.
            if model is not None and r % 200 == 0:
                self.predict(model,
                             device,
                             self.urd.uniform_random_loader(self.N_VALIDATION),
                             prefix="Train")
                self.save_testdata_prediction(model=model, device=device)

        if model is not None:
            self.save_testdata_prediction(model=model, device=device)

    def save_prediction(self, predition):
        # NOTE(review): parameter name 'predition' is misspelled but kept
        # to avoid changing the interface.
        if isinstance(predition, (np.ndarray, )):
            predition = predition.reshape(-1).tolist()

        # One predicted class id per line in result/result.txt.
        with open(os.path.join(self.RESULT_DIR, 'result.txt'), 'w') as fout:
            fout.writelines(os.linesep.join([str(n) for n in predition]))

    def save_testdata_prediction(self, model, device):
        # Predict over the pickled test set and persist the class ids.
        loader = get_test_loader(batch_size=1000)
        prediction = []
        with torch.no_grad():
            for data in loader:
                pred = model(data.to(device)).argmax(dim=1, keepdim=True)
                prediction.extend(pred.reshape(-1).tolist())

        self.save_prediction(prediction)

    def predict(self, model, device, test_loader, prefix=""):
        """Evaluate `model` on `test_loader`, printing a classification
        report plus average loss and accuracy (prefixed with `prefix`)."""
        model.eval()
        test_loss = 0
        correct = 0
        prediction = []
        real = []
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.nll_loss(
                    output, target,
                    reduction='sum').item()  # sum up batch loss
                pred = output.argmax(
                    dim=1,
                    keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
                prediction.extend(pred.reshape(-1).tolist())
                real.extend(target.reshape(-1).tolist())

        test_loss /= len(test_loader.dataset)
        acc = 100. * correct / len(test_loader.dataset)

        print(classification_report(real, prediction))
        print(
            '{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                prefix, test_loss, correct, len(test_loader.dataset), acc), )
def suite():
    """Build a TestSuite containing only the federated-averaging test."""
    test_suite = unittest.TestSuite()
    test_suite.addTest(FedAveragingGradsTestSuit('test_federated_averaging'))
    return test_suite
def main():
    """Run the federated-averaging suite with a plain text test runner."""
    unittest.TextTestRunner().run(suite())


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.utils.data | |
# Training CSVs live under this directory, one file per user.
# Bug fix: this must be a RAW string — in the original plain literal the
# "\train" segment parsed "\t" as a TAB character, silently corrupting the
# path so os.walk() found nothing.
TRAINDATA_DIR = r'N:\dataset\media_competetions_manual-uploaded-datasets_train.tar\media_competetions_manual-uploaded-datasets_train\train/'
TESTDATA_PATH = './test/testing-X.pkl'

# Label mapping: attack-type token (lower-cased, dashes stripped) -> class id.
ATTACK_TYPES = {
    'snmp': 0,
    'portmap': 1,
    'syn': 2,
    'dns': 3,
    'ssdp': 4,
    'webddos': 5,
    'mssql': 6,
    'tftp': 7,
    'ntp': 8,
    'udplag': 9,
    'ldap': 10,
    'netbios': 11,
    'udp': 12,
    'benign': 13,
}
class CompDataset(object):
    """Minimal (feature, label) dataset usable with a torch DataLoader."""

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
        # Pre-zip samples so __getitem__ is a plain list lookup.
        self._data = list(zip(X, Y))

    def __getitem__(self, idx):
        return self._data[idx]

    def __len__(self):
        return len(self._data)
def extract_features(data, has_label=True):
    """Zero the SimillarHTTP column and slice out the 79 feature columns.

    Mutates `data` in place (SimillarHTTP is set to 0.0), matching the
    original behavior. 'SimillarHTTP' is the column name as spelled in the
    source CSVs; the misspelling is intentional.
    """
    data['SimillarHTTP'] = 0.
    # With a trailing label column, features are the 79 columns before it;
    # otherwise they are simply the last 79 columns.
    feature_cols = slice(-80, -1) if has_label else slice(-79, None)
    return data.iloc[:, feature_cols]
class UserRoundData(object):
    """Loads per-user training data (one CSV per user under TRAINDATA_DIR)
    and serves per-round samples for federated training."""

    def __init__(self):
        self.data_dir = TRAINDATA_DIR
        self._user_datasets = []
        self.attack_types = ATTACK_TYPES
        self._load_data()

    def _get_data(self, fpath):
        """Parse one user's CSV into (features, labels); None for non-CSV."""
        if not fpath.endswith('csv'):
            return
        print('Load User Data: ', os.path.basename(fpath))
        data = pd.read_csv(fpath, skipinitialspace=True, low_memory=False)
        x = extract_features(data)
        # Last column holds labels like "..._DrDoS-DNS": keep the final
        # '_'-token, strip dashes, lower-case to match ATTACK_TYPES keys.
        y = np.array([
            self.attack_types[t.split('_')[-1].replace('-', '').lower()]
            for t in data.iloc[:, -1]
        ])
        x = x.to_numpy().astype(np.float32)
        # Bug fix: the original only replaced +inf (x == np.inf), letting
        # -inf survive the NaN pass below and poison training; np.isinf
        # catches both signs.
        x[np.isinf(x)] = 1.
        x[np.isnan(x)] = 0.
        return (
            x,
            y,
        )

    def _load_data(self):
        """Walk data_dir and load each CSV file as one user's dataset."""
        self._user_datasets = []
        for root, dirs, fnames in os.walk(self.data_dir):
            for fname in fnames:
                # each file is for each user
                # user data can not be shared among users
                data = self._get_data(os.path.join(root, fname))
                if data is not None:
                    self._user_datasets.append(data)

        self.n_users = len(self._user_datasets)

    def round_data(self, user_idx, n_round, n_round_samples=-1):
        """Generate data for user of user_idx at round n_round.

        Args:
            user_idx: int, in [0, self.n_users)
            n_round: int, round number (currently unused in the selection)
            n_round_samples: how many samples to draw; -1 returns the full
                dataset for this user.
        """
        if n_round_samples == -1:
            return self._user_datasets[user_idx]

        n_samples = len(self._user_datasets[user_idx][1])
        # NOTE(review): np.random.choice samples WITH replacement by default,
        # so a round may contain duplicate samples — presumably acceptable
        # for SGD; confirm before changing.
        choices = np.random.choice(n_samples, min(n_samples, n_round_samples))
        return self._user_datasets[user_idx][0][choices], self._user_datasets[
            user_idx][1][choices]

    def uniform_random_loader(self, n_samples, batch_size=1000):
        """DataLoader over ~n_samples drawn evenly across all users."""
        X, Y = [], []
        n_samples_each_user = n_samples // len(self._user_datasets)
        if n_samples_each_user <= 0:
            n_samples_each_user = 1
        for idx in range(len(self._user_datasets)):
            x, y = self.round_data(user_idx=idx,
                                   n_round=0,
                                   n_round_samples=n_samples_each_user)
            X.append(x)
            Y.append(y)

        data = CompDataset(X=np.concatenate(X), Y=np.concatenate(Y))
        train_loader = torch.utils.data.DataLoader(
            data,
            batch_size=min(batch_size, n_samples),
            shuffle=True,
        )

        return train_loader
def get_test_loader(batch_size=1000):
    """Deserialize the pickled test features and wrap them in a DataLoader.

    NOTE(review): loads a pickle from a fixed relative path — only safe for
    trusted files.
    """
    with open(TESTDATA_PATH, 'rb') as fin:
        data = pickle.load(fin)

    return torch.utils.data.DataLoader(
        data['X'],
        batch_size=batch_size,
        shuffle=False,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
from preprocess import CompDataset | |
def user_round_train(X, Y, model, device, debug=False):
    """Train `model` on one user's round data and collect its gradients.

    Args:
        X, Y: this user's round features/labels (wrapped in CompDataset).
        model: torch module; trained in place on `device`.
        device: torch device to run on.
        debug: print loss/accuracy when True.

    Returns:
        dict with:
            'n_samples': number of samples trained on this round
            'named_grads': parameter-name -> numpy gradient array
    """
    data = CompDataset(X=X, Y=Y)
    train_loader = torch.utils.data.DataLoader(
        data,
        batch_size=320,
        shuffle=True,
    )

    model.train()
    correct = 0
    prediction = []
    real = []
    total_loss = 0
    model = model.to(device)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = F.nll_loss(output, target)
        # Bug fix: accumulate a python float; the original summed live loss
        # tensors, retaining every batch's autograd graph in memory.
        total_loss += loss.item()
        # No zero_grad here: gradients intentionally accumulate across all
        # batches so the returned grads cover the whole round.
        loss.backward()
        pred = output.argmax(
            dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        prediction.extend(pred.reshape(-1).tolist())
        real.extend(target.reshape(-1).tolist())

    # Bug fix: report the full round's sample count. The original used
    # data.shape[0], i.e. the size of only the LAST batch, which skewed the
    # sample-weighted averaging performed in aggregate_grads.
    grads = {'n_samples': len(train_loader.dataset), 'named_grads': {}}
    for name, param in model.named_parameters():
        grads['named_grads'][name] = param.grad.detach().cpu().numpy()

    if debug:
        print('Training Loss: {:<10.2f}, accuracy: {:<8.2f}'.format(
            total_loss, 100. * correct / len(train_loader.dataset)))

    return grads
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment