Created
June 11, 2021 10:02
-
-
Save DDoSolitary/3daacb30015f7fd14c0e4fa4ca751bbe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from __future__ import annotations | |
import functools | |
import itertools | |
import json | |
import multiprocessing | |
import random | |
import re | |
import string | |
import subprocess | |
import tempfile | |
import networkx as nx | |
from abc import ABC | |
from argparse import ArgumentParser | |
from collections import defaultdict, deque | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Optional, Union | |
# Number of requests emitted per generated test case.
REQUEST_COUNT = 1000

# Person names and notice texts are random alphanumeric strings of fixed length.
NAME_LEN = 10
NAME_ALPHABET = string.ascii_letters + string.digits
NOTICE_LEN = 100
NOTICE_ALPHABET = NAME_ALPHABET

# Every (MIN_*, MAX_*) pair is fed to random.randrange, whose upper bound is
# exclusive; the "+ 1" makes the nominal maximum itself reachable.
MAX_ID = 1 << 31  # ids are drawn from [-2**31, 2**31), the 32-bit signed range
MIN_ID = -MAX_ID
MIN_AGE, MAX_AGE = 0, 200 + 1
MIN_VALUE, MAX_VALUE = 0, 1000 + 1
MIN_SOCIAL_VALUE, MAX_SOCIAL_VALUE = -1000, 1000 + 1
MIN_EMOJI_ID, MAX_EMOJI_ID = 0, 10000 + 1
MIN_MONEY, MAX_MONEY = 0, 200 + 1
# No "+ 1" here: generated heat limits fall in [0, 9].
MIN_HEAT_LIMIT, MAX_HEAT_LIMIT = 0, 10
class InputRequest(ABC):
    """Base class for one request (one line) of tester input.

    Subclasses store their command mnemonic in ``op`` and turn themselves
    back into an input line through ``__str__``.
    """

    op: str  # command mnemonic such as 'ap' or 'qgvs'

    def __init__(self, op: str):
        self.op = op

    @staticmethod
    def parse(s: str) -> InputRequest:
        """Build the request described by the input line ``s``.

        The first whitespace-separated token selects the request type and the
        remaining tokens are its arguments. All arguments go through ``int``
        except the person name of 'ap' and the notice text of 'anm'.

        Raises:
            IndexError: if the line is empty or has too few arguments.
            ValueError: 'invalid op' for an unknown command, or from ``int``
                for a malformed numeric argument.
        """
        tokens = s.split()
        command = tokens[0]
        # Lambdas defer the lookup of the request classes until one is needed.
        builders = {
            'ap': lambda t: AddPersonRequest(int(t[1]), t[2], int(t[3])),
            'ar': lambda t: AddRelationRequest(int(t[1]), int(t[2]), int(t[3])),
            'qv': lambda t: QueryValueRequest(int(t[1]), int(t[2])),
            'cn': lambda t: CompareNameRequest(int(t[1]), int(t[2])),
            'qnr': lambda t: QueryNameRankRequest(int(t[1])),
            'qps': lambda t: QueryPeopleSumRequest(),
            'qci': lambda t: QueryCircleRequest(int(t[1]), int(t[2])),
            'qbs': lambda t: QueryBlockSumRequest(),
            'ag': lambda t: AddGroupRequest(int(t[1])),
            'atg': lambda t: AddToGroupRequest(int(t[1]), int(t[2])),
            'qgs': lambda t: QueryGroupSumRequest(),
            'qgps': lambda t: QueryGroupPeopleSumRequest(int(t[1])),
            'qgvs': lambda t: QueryGroupValueSumRequest(int(t[1])),
            'qgam': lambda t: QueryGroupAgeMeanRequest(int(t[1])),
            'qgav': lambda t: QueryGroupAgeVarRequest(int(t[1])),
            'dfg': lambda t: DelFromGroupRequest(int(t[1]), int(t[2])),
            'am': lambda t: AddMessageRequest(
                int(t[1]), int(t[2]), int(t[3]), int(t[4]), int(t[5])),
            'sm': lambda t: SendMessageRequest(int(t[1])),
            'qsv': lambda t: QuerySocialValueRequest(int(t[1])),
            'qrm': lambda t: QueryReceivedMessagesRequest(int(t[1])),
            'arem': lambda t: AddRedEnvelopeMessageRequest(
                int(t[1]), int(t[2]), int(t[3]), int(t[4]), int(t[5])),
            'anm': lambda t: AddNoticeMessageRequest(
                int(t[1]), t[2], int(t[3]), int(t[4]), int(t[5])),
            'aem': lambda t: AddEmojiMessageRequest(
                int(t[1]), int(t[2]), int(t[3]), int(t[4]), int(t[5])),
            'sei': lambda t: StoreEmojiIdRequest(int(t[1])),
            'qp': lambda t: QueryPopularityRequest(int(t[1])),
            'dce': lambda t: DeleteColdEmojiRequest(int(t[1])),
            'qm': lambda t: QueryMoneyRequest(int(t[1])),
            'sim': lambda t: SendIndirectMessageRequest(int(t[1])),
        }
        build = builders.get(command)
        if build is None:
            raise ValueError('invalid op')
        return build(tokens)
class AddPersonRequest(InputRequest):
    """``ap id name age`` -- add a new person to the network."""

    id: int
    name: str
    age: int

    def __init__(self, _id: int, name: str, age: int):
        super().__init__('ap')
        self.id, self.name, self.age = _id, name, age

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id, self.name, self.age)))
class AddRelationRequest(InputRequest):
    """``ar id1 id2 value`` -- link two people with a relation of ``value``."""

    id1: int
    id2: int
    value: int

    def __init__(self, id1: int, id2: int, value: int):
        super().__init__('ar')
        self.id1, self.id2, self.value = id1, id2, value

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id1, self.id2, self.value)))
class QueryValueRequest(InputRequest):
    """``qv id1 id2`` -- query the value of the relation between two people."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int):
        super().__init__('qv')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id1, self.id2)))
class CompareNameRequest(InputRequest):
    """``cn id1 id2`` -- compare two people's names ('<', '=' or '>')."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int):
        super().__init__('cn')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id1, self.id2)))
class QueryNameRankRequest(InputRequest):
    """``qnr id`` -- query the 1-based rank of a person's name."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qnr')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QueryPeopleSumRequest(InputRequest):
    """``qps`` -- query the number of people in the network."""

    def __init__(self):
        super().__init__('qps')

    def __str__(self) -> str:
        return str(self.op)
class QueryCircleRequest(InputRequest):
    """``qci id1 id2`` -- query whether two people are connected (1 or 0)."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int):
        super().__init__('qci')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id1, self.id2)))
class QueryBlockSumRequest(InputRequest):
    """``qbs`` -- query the number of connected components of the network."""

    def __init__(self):
        super().__init__('qbs')

    def __str__(self) -> str:
        return str(self.op)
class AddGroupRequest(InputRequest):
    """``ag id`` -- create a new, empty group."""

    id: int

    def __init__(self, _id: int):
        super().__init__('ag')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class AddToGroupRequest(InputRequest):
    """``atg id1 id2`` -- add person ``id1`` to group ``id2``."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int):
        super().__init__('atg')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id1, self.id2)))
class QueryGroupSumRequest(InputRequest):
    """``qgs`` -- query the number of groups."""

    def __init__(self):
        super().__init__('qgs')

    def __str__(self) -> str:
        return str(self.op)
class QueryGroupPeopleSumRequest(InputRequest):
    """``qgps id`` -- query how many people a group contains."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qgps')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QueryGroupValueSumRequest(InputRequest):
    """``qgvs id`` -- query the sum of relation values inside a group."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qgvs')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QueryGroupAgeMeanRequest(InputRequest):
    """``qgam id`` -- query the (integer) mean age of a group."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qgam')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QueryGroupAgeVarRequest(InputRequest):
    """``qgav id`` -- query the (integer) age variance of a group."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qgav')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class DelFromGroupRequest(InputRequest):
    """``dfg id1 id2`` -- remove person ``id1`` from group ``id2``."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int):
        super().__init__('dfg')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id1, self.id2)))
class AddMessageRequest(InputRequest):
    """``am id social_value type id1 id2`` -- add an ordinary message.

    The sender is person ``id1``; ``id2`` is a person when ``type`` is 0 and
    a group when ``type`` is 1.
    """

    id: int
    social_value: int
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, social_value: int, _type: int, id1: int, id2: int):
        super().__init__('am')
        self.id, self.social_value, self.type = _id, social_value, _type
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        fields = (self.op, self.id, self.social_value, self.type, self.id1, self.id2)
        return ' '.join(map(str, fields))
class SendMessageRequest(InputRequest):
    """``sm id`` -- deliver a pending message to its person or group."""

    id: int

    def __init__(self, _id: int):
        super().__init__('sm')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QuerySocialValueRequest(InputRequest):
    """``qsv id`` -- query a person's accumulated social value."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qsv')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QueryReceivedMessagesRequest(InputRequest):
    """``qrm id`` -- query the (up to four) latest messages a person received."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qrm')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class AddRedEnvelopeMessageRequest(InputRequest):
    """``arem id money type id1 id2`` -- add a red-envelope (money) message.

    ``type`` 0 sends to person ``id2``; ``type`` 1 sends to group ``id2``.
    """

    id: int
    money: int
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, money: int, _type: int, id1: int, id2: int):
        super().__init__('arem')
        self.id, self.money, self.type = _id, money, _type
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        fields = (self.op, self.id, self.money, self.type, self.id1, self.id2)
        return ' '.join(map(str, fields))
class AddNoticeMessageRequest(InputRequest):
    """``anm id string type id1 id2`` -- add a notice (text) message.

    ``type`` 0 sends to person ``id2``; ``type`` 1 sends to group ``id2``.
    """

    id: int
    string: str  # notice text; a single whitespace-free token in the input
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, _string: str, _type: int, id1: int, id2: int):
        super().__init__('anm')
        self.id, self.string, self.type = _id, _string, _type
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        fields = (self.op, self.id, self.string, self.type, self.id1, self.id2)
        return ' '.join(map(str, fields))
class AddEmojiMessageRequest(InputRequest):
    """``aem id emoji_id type id1 id2`` -- add an emoji message.

    ``type`` 0 sends to person ``id2``; ``type`` 1 sends to group ``id2``.
    """

    id: int
    emoji_id: int
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, emoji_id: int, _type: int, id1: int, id2: int):
        super().__init__('aem')
        self.id, self.emoji_id, self.type = _id, emoji_id, _type
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        fields = (self.op, self.id, self.emoji_id, self.type, self.id1, self.id2)
        return ' '.join(map(str, fields))
class StoreEmojiIdRequest(InputRequest):
    """``sei id`` -- register a new emoji id with zero heat."""

    id: int

    def __init__(self, _id: int):
        super().__init__('sei')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class QueryPopularityRequest(InputRequest):
    """``qp id`` -- query the heat (times sent) of an emoji."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qp')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class DeleteColdEmojiRequest(InputRequest):
    """``dce limit`` -- delete every emoji whose heat is below ``limit``."""

    limit: int

    def __init__(self, limit: int):
        super().__init__('dce')
        self.limit = limit

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.limit)))
class QueryMoneyRequest(InputRequest):
    """``qm id`` -- query a person's money balance."""

    id: int

    def __init__(self, _id: int):
        super().__init__('qm')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
class SendIndirectMessageRequest(InputRequest):
    """``sim id`` -- deliver a person-to-person message along the cheapest path."""

    id: int

    def __init__(self, _id: int):
        super().__init__('sim')
        self.id = _id

    def __str__(self) -> str:
        return ' '.join(map(str, (self.op, self.id)))
@dataclass
class Person:
    """A member of the simulated social network.

    People are used as networkx graph nodes and as group-set members, so
    identity is defined by ``id`` alone: equality and hashing ignore every
    other (mutable) field.
    """

    id: int
    name: str
    age: int
    social_value: int
    money: int
    messages: deque[Message]  # received messages, newest first

    def __hash__(self) -> int:
        # Must hash ``self.id``; hashing the builtin ``id`` function would give
        # every Person the same hash and turn every set/dict/graph lookup into
        # a linear scan over all people.
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Person):
            return NotImplemented
        return self.id == other.id
@dataclass()
class Group:
    """A group of people; type 1 messages are broadcast to all its members."""

    id: int
    # Current members; ``atg`` adds and ``dfg`` removes people in place.
    members: set[Person]
@dataclass()
class Message:
    """An ordinary message, pending in the network until it is sent.

    ``dst`` is a Person for direct (type 0) messages and a Group for group
    (type 1) messages. ``str()`` gives the text reported by ``qrm``.
    """

    id: int
    social_value: int  # added to the social value of everyone involved on delivery
    src: Person
    dst: Union[Person, Group]

    def __str__(self) -> str:
        text = 'Ordinary message'
        return text
@dataclass()
class NoticeMessage(Message):
    """A text notice; the network gives it a social value equal to the text length."""

    notice: str

    def __str__(self) -> str:
        return 'notice: {}'.format(self.notice)
@dataclass()
class EmojiMessage(Message):
    """An emoji message; the network gives it a social value equal to its emoji id.

    Sending it increments that emoji's heat.
    """

    emoji_id: int

    def __str__(self) -> str:
        return 'Emoji: {}'.format(self.emoji_id)
@dataclass()
class RedEnvelopeMessage(Message):
    """A money transfer; the network gives it a social value of five times ``money``."""

    money: int

    def __str__(self) -> str:
        return 'RedEnvelope: {}'.format(self.money)
class ExceptionCounter:
    """Formats counted exception messages for a single exception kind.

    Every report bumps a counter for the kind as a whole and a counter for
    each id involved. The result is rendered as ``tag-N, id-M`` or
    ``tag-N, id1-M, id2-K``.
    """

    tag: str
    global_counter: int
    person_counters: defaultdict[int, int]

    def __init__(self, tag: str):
        self.tag, self.global_counter, self.person_counters = tag, 0, defaultdict(int)

    def count(self, _id: int) -> str:
        """Record one exception about ``_id`` and return its message."""
        self.global_counter += 1
        self.person_counters[_id] += 1
        return '{}-{}, {}-{}'.format(
            self.tag, self.global_counter, _id, self.person_counters[_id])

    def count_both(self, id1: int, id2: int) -> str:
        """Record one exception involving two ids and return its message.

        The ids are reported smaller first. When both ids are the same, that
        id's counter is bumped only once.
        """
        self.global_counter += 1
        low, high = (id2, id1) if id2 < id1 else (id1, id2)
        self.person_counters[low] += 1
        if high != low:
            self.person_counters[high] += 1
        return (f'{self.tag}-{self.global_counter}, '
                f'{low}-{self.person_counters[low]}, '
                f'{high}-{self.person_counters[high]}')
class PeopleNetwork:
    """Reference model of the social network that the subjects implement.

    ``process_request`` applies one request to the model and returns the
    exact line a correct implementation is expected to print for it.
    Exceptional outcomes are rendered by one ``ExceptionCounter`` per
    exception tag.
    """

    people: dict[int, Person]
    groups: dict[int, Group]
    messages: dict[int, Message]  # added but not yet sent, keyed by message id
    emojis: dict[int, int]  # stored emoji id -> heat (times it has been sent)
    graph: nx.Graph  # Person nodes; every edge carries a 'value' attribute
    pinfCounter: ExceptionCounter
    epiCounter: ExceptionCounter
    rnfCounter: ExceptionCounter
    erCounter: ExceptionCounter
    ginfCounter: ExceptionCounter
    egiCounter: ExceptionCounter
    minfCounter: ExceptionCounter
    emiCounter: ExceptionCounter
    einfCounter: ExceptionCounter
    eeiCounter: ExceptionCounter

    def __init__(self):
        self.people = dict()
        self.groups = dict()
        self.messages = dict()
        self.emojis = dict()
        self.graph = nx.Graph()
        self.pinfCounter = ExceptionCounter('pinf')
        self.epiCounter = ExceptionCounter('epi')
        self.rnfCounter = ExceptionCounter('rnf')
        self.erCounter = ExceptionCounter('er')
        self.ginfCounter = ExceptionCounter('ginf')
        self.egiCounter = ExceptionCounter('egi')
        self.minfCounter = ExceptionCounter('minf')
        self.emiCounter = ExceptionCounter('emi')
        self.einfCounter = ExceptionCounter('einf')
        self.eeiCounter = ExceptionCounter('eei')

    def process_request(self, req: InputRequest) -> str:
        """Apply ``req`` to the model and return the expected output line.

        Successful mutations return 'Ok', queries return their result as a
        string, and failures return the text produced by the matching
        exception counter (or a fixed text for missing message endpoints).

        Raises:
            ValueError: if ``req`` is not one of the known request types.
        """
        if isinstance(req, AddPersonRequest):
            if req.id in self.people:
                return self.epiCounter.count(req.id)
            p = Person(req.id, req.name, req.age, 0, 0, deque())
            self.people[req.id] = p
            # A zero-valued self-loop makes every person a graph node (so qbs
            # counts isolated people) and makes a person linked to itself.
            self.graph.add_edge(p, p, value=0)
            return 'Ok'
        elif isinstance(req, AddRelationRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            if self.graph.has_edge(p1, p2):
                return self.erCounter.count_both(req.id1, req.id2)
            self.graph.add_edge(p1, p2, value=req.value)
            return 'Ok'
        elif isinstance(req, QueryValueRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            attrs = self.graph.get_edge_data(p1, p2)
            if attrs is None:
                return self.rnfCounter.count_both(req.id1, req.id2)
            return str(attrs['value'])
        elif isinstance(req, CompareNameRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            if p1.name < p2.name:
                return '<'
            if p1.name == p2.name:
                return '='
            return '>'
        elif isinstance(req, QueryNameRankRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            # Rank = 1 + number of people whose name sorts strictly earlier.
            return str(sum(1 for o in self.people.values() if o.name < p.name) + 1)
        elif isinstance(req, QueryPeopleSumRequest):
            return str(len(self.people))
        elif isinstance(req, QueryCircleRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            return '1' if nx.has_path(self.graph, p1, p2) else '0'
        elif isinstance(req, QueryBlockSumRequest):
            return str(nx.number_connected_components(self.graph))
        elif isinstance(req, AddGroupRequest):
            if req.id in self.groups:
                return self.egiCounter.count(req.id)
            self.groups[req.id] = Group(req.id, set())
            return 'Ok'
        elif isinstance(req, AddToGroupRequest):
            g = self.groups.get(req.id2)
            if g is None:
                return self.ginfCounter.count(req.id2)
            p = self.people.get(req.id1)
            if p is None:
                return self.pinfCounter.count(req.id1)
            if p in g.members:
                return self.epiCounter.count(req.id1)
            g.members.add(p)
            return 'Ok'
        elif isinstance(req, QueryGroupSumRequest):
            return str(len(self.groups))
        elif isinstance(req, QueryGroupPeopleSumRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            return str(len(g.members))
        elif isinstance(req, QueryGroupValueSumRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            # Each undirected edge is counted once per direction, hence the 2x;
            # self-loops carry value 0 and do not affect the sum.
            return str(2 * sum(v for _, _, v in self.graph.subgraph(g.members).edges.data('value')))
        elif isinstance(req, QueryGroupAgeMeanRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            if len(g.members) == 0:
                return '0'
            return str(sum(p.age for p in g.members) // len(g.members))
        elif isinstance(req, QueryGroupAgeVarRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            if len(g.members) == 0:
                return '0'
            # Variance is taken around the integer (floored) mean.
            mean = sum(p.age for p in g.members) // len(g.members)
            return str(sum((p.age - mean) ** 2 for p in g.members) // len(g.members))
        elif isinstance(req, DelFromGroupRequest):
            g = self.groups.get(req.id2)
            if g is None:
                return self.ginfCounter.count(req.id2)
            p = self.people.get(req.id1)
            if p is None:
                return self.pinfCounter.count(req.id1)
            if p not in g.members:
                return self.epiCounter.count(req.id1)
            g.members.remove(p)
            return 'Ok'
        elif isinstance(req, (AddMessageRequest, AddRedEnvelopeMessageRequest,
                              AddNoticeMessageRequest, AddEmojiMessageRequest)):
            # Missing endpoints are reported with fixed texts, not counted
            # exceptions.
            if req.type == 1 and req.id2 not in self.groups:
                return 'Group does not exist'
            if req.id1 not in self.people or (req.type == 0 and req.id2 not in self.people):
                return 'The person with this number does not exist'
            if req.id in self.messages:
                return self.emiCounter.count(req.id)
            src = self.people[req.id1]
            dst = self.people[req.id2] if req.type == 0 else self.groups[req.id2]
            msg: Message
            if isinstance(req, AddMessageRequest):
                msg = Message(req.id, req.social_value, src, dst)
            elif isinstance(req, AddRedEnvelopeMessageRequest):
                msg = RedEnvelopeMessage(req.id, req.money * 5, src, dst, req.money)
            elif isinstance(req, AddNoticeMessageRequest):
                msg = NoticeMessage(req.id, len(req.string), src, dst, req.string)
            elif isinstance(req, AddEmojiMessageRequest):
                if req.emoji_id not in self.emojis:
                    return self.einfCounter.count(req.emoji_id)
                msg = EmojiMessage(req.id, req.emoji_id, src, dst, req.emoji_id)
            else:
                raise Exception('unreachable code')
            # A person may not send a direct message to themselves.
            if isinstance(dst, Person) and src == dst:
                return self.epiCounter.count(src.id)
            self.messages[msg.id] = msg
            return 'Ok'
        elif isinstance(req, SendMessageRequest):
            m = self.messages.get(req.id)
            if m is None:
                return self.minfCounter.count(req.id)
            if isinstance(m.dst, Person):
                if not self.graph.has_edge(m.src, m.dst):
                    return self.rnfCounter.count_both(m.src.id, m.dst.id)
                m.src.social_value += m.social_value
                m.dst.social_value += m.social_value
                if isinstance(m, RedEnvelopeMessage):
                    m.src.money -= m.money
                    m.dst.money += m.money
                m.dst.messages.appendleft(m)
            else:
                if m.src not in m.dst.members:
                    return self.pinfCounter.count(m.src.id)
                for gp in m.dst.members:
                    gp.social_value += m.social_value
                if isinstance(m, RedEnvelopeMessage):
                    # Every other member gets an equal integer share; the
                    # sender pays only for the shares actually handed out.
                    money = m.money // len(m.dst.members)
                    m.src.money -= money * (len(m.dst.members) - 1)
                    for gp in m.dst.members:
                        if gp != m.src:
                            gp.money += money
            if isinstance(m, EmojiMessage):
                self.emojis[m.emoji_id] += 1
            self.messages.pop(m.id)
            return 'Ok'
        elif isinstance(req, QuerySocialValueRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            return str(p.social_value)
        elif isinstance(req, QueryReceivedMessagesRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            if len(p.messages) == 0:
                return 'None'
            # Messages are stored newest first; report at most four.
            return '; '.join(str(msg) for msg in itertools.islice(p.messages, 0, 4))
        elif isinstance(req, StoreEmojiIdRequest):
            if req.id in self.emojis:
                return self.eeiCounter.count(req.id)
            self.emojis[req.id] = 0
            return 'Ok'
        elif isinstance(req, QueryPopularityRequest):
            if req.id not in self.emojis:
                return self.einfCounter.count(req.id)
            return str(self.emojis[req.id])
        elif isinstance(req, DeleteColdEmojiRequest):
            # Keep only emojis whose heat reaches the limit.
            self.emojis = {k: v for k, v in self.emojis.items() if v >= req.limit}
            # Drop pending emoji messages whose emoji was just deleted; every
            # non-emoji message must survive untouched.
            self.messages = {
                k: v for k, v in self.messages.items()
                if not isinstance(v, EmojiMessage) or v.emoji_id in self.emojis}
            return str(len(self.emojis))
        elif isinstance(req, QueryMoneyRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            return str(p.money)
        elif isinstance(req, SendIndirectMessageRequest):
            m = self.messages.get(req.id)
            # Only pending person-to-person messages can be sent indirectly.
            if m is None or not isinstance(m.dst, Person):
                return self.minfCounter.count(req.id)
            try:
                # Cheapest path, weighting each relation by its value.
                dis = nx.shortest_path_length(self.graph, m.src, m.dst, 'value')
            except nx.NetworkXNoPath:
                # Unreachable recipient: the message stays pending.
                return '-1'
            m.src.social_value += m.social_value
            m.dst.social_value += m.social_value
            if isinstance(m, RedEnvelopeMessage):
                m.src.money -= m.money
                m.dst.money += m.money
            elif isinstance(m, EmojiMessage):
                self.emojis[m.emoji_id] += 1
            m.dst.messages.appendleft(m)
            self.messages.pop(m.id)
            return str(dis)
        else:
            raise ValueError('invalid request')
@dataclass()
class TestError(Exception):
    """Raised when a subject fails one test case.

    Carries everything needed to report and log the failure:
    - the reason;
    - the input that was fed to the subject;
    - the expected output lines;
    - the finished process (None on timeout);
    - the first mismatching 1-based output line (only for wrong answers).
    """

    reason: str  # 'Time Limit Exceeded', 'Runtime Error' or 'Wrong Answer'
    input: str
    ans: list[str]
    proc: Optional[subprocess.CompletedProcess]
    err_line: Optional[int]
def gen_test_case(request_count: Optional[int] = None) -> tuple[list[InputRequest], list[str]]:
    """Generate a random test case together with its expected output.

    Each request is applied to a fresh ``PeopleNetwork`` model as soon as it
    is generated. Later requests can therefore be biased towards ids, names,
    groups, messages and emojis that already exist. With a small probability
    (typically 1 in 10) the helpers deliberately pick a duplicate or
    nonexistent value so that exception paths are exercised.

    Args:
        request_count: number of requests to generate; defaults to
            ``REQUEST_COUNT``.

    Returns:
        The generated requests and, in the same order, the line the model
        expects the program under test to print for each of them.
    """
    if request_count is None:
        request_count = REQUEST_COUNT
    requests: list[InputRequest] = []
    responses: list[str] = []
    network = PeopleNetwork()

    def gen_id() -> int:
        return random.randrange(MIN_ID, MAX_ID)

    def gen_used_person_id() -> int:
        return random.choice(tuple(network.people.keys()))

    def gen_new_person_id() -> int:
        # Occasionally reuse an existing id to trigger the duplicate-person path.
        if random.randrange(10) == 0 and len(network.people) > 0:
            return gen_used_person_id()
        return gen_id()

    def gen_person_id() -> int:
        # Mostly existing people, occasionally a (very likely) unknown id.
        if random.randrange(10) == 0 or len(network.people) == 0:
            return gen_id()
        return gen_used_person_id()

    def gen_person_id_pair() -> tuple[int, int]:
        rand = random.randrange(10)
        if rand == 0 or len(network.people) == 0:
            # Two random ids, sometimes equal to each other.
            if random.randrange(5) == 0:
                ret = gen_id()
                return ret, ret
            return gen_id(), gen_id()
        if rand == 1:
            # One existing and one random id, in random order.
            id1 = gen_used_person_id()
            id2 = gen_id()
            if random.randrange(0, 2) == 0:
                id1, id2 = id2, id1
            return id1, id2
        # Two existing ids, sometimes the same person twice.
        if random.randrange(5) == 0:
            ret = gen_used_person_id()
            return ret, ret
        return gen_used_person_id(), gen_used_person_id()

    def gen_used_group_id() -> int:
        return random.choice(tuple(network.groups.keys()))

    def gen_new_group_id() -> int:
        if random.randrange(10) == 0 and len(network.groups) > 0:
            return gen_used_group_id()
        return gen_id()

    def gen_group_id() -> int:
        if random.randrange(10) == 0 or len(network.groups) == 0:
            return gen_id()
        return gen_used_group_id()

    def gen_used_msg_id() -> int:
        return random.choice(tuple(network.messages.keys()))

    def gen_new_msg_id() -> int:
        if random.randrange(10) == 0 and len(network.messages) > 0:
            return gen_used_msg_id()
        return gen_id()

    def gen_msg_id() -> int:
        if random.randrange(10) == 0 or len(network.messages) == 0:
            return gen_id()
        return gen_used_msg_id()

    def gen_name() -> str:
        # Occasionally reuse an existing name so equal names are compared/ranked.
        if random.randrange(10) == 0 and len(network.people) > 0:
            return random.choice(tuple(p.name for p in network.people.values()))
        return ''.join(random.choice(NAME_ALPHABET) for _ in range(NAME_LEN))

    def gen_age() -> int:
        return random.randrange(MIN_AGE, MAX_AGE)

    def gen_value() -> int:
        return random.randrange(MIN_VALUE, MAX_VALUE)

    def gen_social_value() -> int:
        return random.randrange(MIN_SOCIAL_VALUE, MAX_SOCIAL_VALUE)

    def gen_used_emoji_id() -> int:
        return random.choice(tuple(network.emojis.keys()))

    def gen_new_emoji_id() -> int:
        if random.randrange(10) == 0 and len(network.emojis) > 0:
            return gen_used_emoji_id()
        return random.randrange(MIN_EMOJI_ID, MAX_EMOJI_ID)

    def gen_emoji_id() -> int:
        if random.randrange(10) == 0 or len(network.emojis) == 0:
            return random.randrange(MIN_EMOJI_ID, MAX_EMOJI_ID)
        return gen_used_emoji_id()

    def gen_money() -> int:
        return random.randrange(MIN_MONEY, MAX_MONEY)

    def gen_notice() -> str:
        return ''.join(random.choice(NOTICE_ALPHABET) for _ in range(NOTICE_LEN))

    def gen_heat_limit() -> int:
        return random.randrange(MIN_HEAT_LIMIT, MAX_HEAT_LIMIT)

    def gen_ap() -> AddPersonRequest:
        return AddPersonRequest(gen_new_person_id(), gen_name(), gen_age())

    def gen_ar() -> AddRelationRequest:
        return AddRelationRequest(*gen_person_id_pair(), gen_value())

    def gen_qv() -> QueryValueRequest:
        return QueryValueRequest(*gen_person_id_pair())

    def gen_cn() -> CompareNameRequest:
        return CompareNameRequest(*gen_person_id_pair())

    def gen_qnr() -> QueryNameRankRequest:
        return QueryNameRankRequest(gen_person_id())

    def gen_qps() -> QueryPeopleSumRequest:
        return QueryPeopleSumRequest()

    def gen_qci() -> QueryCircleRequest:
        return QueryCircleRequest(*gen_person_id_pair())

    def gen_qbs() -> QueryBlockSumRequest:
        return QueryBlockSumRequest()

    def gen_ag() -> AddGroupRequest:
        return AddGroupRequest(gen_new_group_id())

    def gen_atg() -> AddToGroupRequest:
        return AddToGroupRequest(gen_person_id(), gen_group_id())

    def gen_qgs() -> QueryGroupSumRequest:
        return QueryGroupSumRequest()

    def gen_qgps() -> QueryGroupPeopleSumRequest:
        return QueryGroupPeopleSumRequest(gen_group_id())

    def gen_qgvs() -> QueryGroupValueSumRequest:
        return QueryGroupValueSumRequest(gen_group_id())

    def gen_qgam() -> QueryGroupAgeMeanRequest:
        return QueryGroupAgeMeanRequest(gen_group_id())

    def gen_qgav() -> QueryGroupAgeVarRequest:
        return QueryGroupAgeVarRequest(gen_group_id())

    def gen_dfg() -> DelFromGroupRequest:
        return DelFromGroupRequest(gen_person_id(), gen_group_id())

    def gen_am() -> Union[
            AddMessageRequest, AddRedEnvelopeMessageRequest,
            AddNoticeMessageRequest, AddEmojiMessageRequest]:
        _type = random.randrange(2)
        if _type == 0:
            # Direct message: prefer a recipient related to the sender (note
            # that the neighbours include the sender itself via its self-loop).
            id1 = gen_person_id()
            p1 = network.people.get(id1)
            if p1 is not None:
                p2_list = tuple(network.graph.neighbors(p1))
            else:
                p2_list = ()
            if random.randrange(10) == 0 or len(p2_list) == 0:
                id2 = gen_person_id()
            else:
                id2 = random.choice(p2_list).id
        else:
            # Group message: prefer a sender who is a member of the group.
            id2 = gen_group_id()
            g = network.groups.get(id2)
            if random.randrange(10) == 0 or g is None or len(g.members) == 0:
                id1 = gen_person_id()
            else:
                id1 = random.choice(tuple(g.members)).id
        _id = gen_new_msg_id()
        cls = random.randrange(4)
        if cls == 0:
            return AddMessageRequest(_id, gen_social_value(), _type, id1, id2)
        elif cls == 1:
            return AddRedEnvelopeMessageRequest(_id, gen_money(), _type, id1, id2)
        elif cls == 2:
            return AddNoticeMessageRequest(_id, gen_notice(), _type, id1, id2)
        else:
            return AddEmojiMessageRequest(_id, gen_emoji_id(), _type, id1, id2)

    def gen_sm() -> SendMessageRequest:
        return SendMessageRequest(gen_msg_id())

    def gen_qsv() -> QuerySocialValueRequest:
        return QuerySocialValueRequest(gen_person_id())

    def gen_qrm() -> QueryReceivedMessagesRequest:
        return QueryReceivedMessagesRequest(gen_person_id())

    def gen_sei() -> StoreEmojiIdRequest:
        return StoreEmojiIdRequest(gen_new_emoji_id())

    def gen_qp() -> QueryPopularityRequest:
        return QueryPopularityRequest(gen_emoji_id())

    def gen_dce() -> DeleteColdEmojiRequest:
        return DeleteColdEmojiRequest(gen_heat_limit())

    def gen_qm() -> QueryMoneyRequest:
        return QueryMoneyRequest(gen_person_id())

    def gen_sim() -> SendIndirectMessageRequest:
        return SendIndirectMessageRequest(gen_msg_id())

    # All request kinds are equally likely; the tuple is loop-invariant.
    generators = (
        gen_ap, gen_ar, gen_qv, gen_cn, gen_qnr, gen_qps, gen_qci, gen_qbs,
        gen_ag, gen_atg, gen_qgs, gen_qgps, gen_qgvs, gen_qgam, gen_qgav, gen_dfg,
        gen_am, gen_sm, gen_qsv, gen_qrm, gen_sei, gen_qp, gen_dce, gen_qm, gen_sim)
    for _ in range(request_count):
        req: InputRequest = random.choice(generators)()
        requests.append(req)
        # Apply immediately so later generators see the updated state.
        responses.append(network.process_request(req))
    return requests, responses
def parse_input_data(input_data: list[str]) -> tuple[list[InputRequest], list[str]]:
    """Parse an existing input file and compute the reference answers.

    Blank or whitespace-only lines are skipped. Every other line is parsed
    with ``InputRequest.parse`` and applied to a fresh ``PeopleNetwork``.

    Returns:
        The parsed requests and, in the same order, the expected output line
        for each of them.

    Raises:
        ValueError / IndexError: propagated from ``InputRequest.parse`` for
            malformed lines.
    """
    network = PeopleNetwork()
    requests: list[InputRequest] = []
    responses: list[str] = []
    for raw in input_data:
        line = raw.strip()
        if not line:
            continue
        req = InputRequest.parse(line)
        requests.append(req)
        responses.append(network.process_request(req))
    return requests, responses
def _run_subject(cmd, req_str: str, responses: list[str]) -> None:
    """Run one subject on ``req_str``; raise TestError if it misbehaves.

    The checks run in order: a 10-second timeout, then a zero exit status,
    then a line-by-line comparison of stdout against ``responses``.
    """
    try:
        proc = subprocess.run(cmd, input=req_str, capture_output=True, text=True, timeout=10)
    except subprocess.TimeoutExpired:
        raise TestError('Time Limit Exceeded', req_str, responses, None, None)
    if proc.returncode != 0:
        raise TestError('Runtime Error', req_str, responses, proc, None)
    # zip_longest pads with None, so missing or extra lines also mismatch.
    pairs = itertools.zip_longest(proc.stdout.splitlines(), responses)
    first_bad = next((lineno for lineno, (got, want) in enumerate(pairs, start=1) if got != want), None)
    if first_bad is not None:
        raise TestError('Wrong Answer', req_str, responses, proc, first_bad)


def do_test(_, config):
    """Run one test case against every configured subject.

    The first argument is ignored. It exists so that ``main`` can map this
    function over ``range(count)`` with a process pool.

    The test case comes from ``config['input']`` (pre-read lines) when set,
    and is randomly generated otherwise.

    Returns:
        A list of error dicts, one per failing subject that ``filter_error``
        accepts. Each accepted failure also gets its artifacts written to a
        fresh temporary directory.
    """
    source = config['input']
    if source is None:
        requests, responses = gen_test_case()
    else:
        requests, responses = parse_input_data(source)
    req_str = ''.join(f'{req}\n' for req in requests)
    subjects = config['subjects']
    errors = []
    for subject in subjects:
        try:
            _run_subject(subject['cmd'], req_str, responses)
        except TestError as e:
            # Summary fields end up in the JSON report ...
            err = dict(subject=subject['name'], reason=e.reason)
            # ... while bulky artifacts are written to log files.
            err_data = dict(input=e.input, ans=''.join(line + '\n' for line in e.ans))
            proc = e.proc
            if proc is not None:
                for stream in ('stdout', 'stderr'):
                    captured = getattr(proc, stream)
                    if captured is not None:
                        err_data[stream] = captured
                err['exit_code'] = proc.returncode
            if e.err_line is not None:
                err['error_line'] = e.err_line
            if not filter_error(dict(**err, **err_data), config['filters']):
                continue
            log_dir = Path(tempfile.mkdtemp(dir=config.get('log_dir')))
            err['log_dir'] = str(log_dir)
            for name, content in err_data.items():
                (log_dir / f'{name}.txt').write_text(content)
            errors.append(err)
    return errors
def compile_rule(rule: dict):
    """Compile a JSON filter rule in place.

    Every entry except 'action' maps an error field name to a regular
    expression string. Each such string is replaced by its compiled pattern.
    """
    for field_name, source in list(rule.items()):
        if field_name == 'action':
            continue
        rule[field_name] = re.compile(source)
def filter_error(err, rules):
    """Decide whether an error should be reported.

    ``rules`` are checked in order. A rule matches when, for every key other
    than 'action', ``err`` has a non-None value for that key and the rule's
    compiled pattern finds a match in the value's string form. The first
    matching rule whose action is 'accept' or 'ignore' decides the outcome;
    matching rules with any other action are passed over. Errors that no
    rule decides are reported.

    Returns:
        True to report the error, False to drop it.

    Raises:
        KeyError: if a matching rule has no 'action' entry.
    """
    for rule in rules:
        matched = True
        for key, pattern in rule.items():
            if key == 'action':
                continue
            value = err.get(key)
            # Fields such as exit_code and error_line are ints; match against
            # their text, since Pattern.search raises TypeError on non-strings.
            if value is None or pattern.search(str(value)) is None:
                matched = False
                break
        if not matched:
            continue
        action = rule['action']
        if action == 'accept':
            return True
        if action == 'ignore':
            return False
    return True
def main():
    """Command-line entry point.

    Steps:
    1. Load the JSON config and merge the command-line options into it.
    2. Compile the filter rules.
    3. Check a single input file (``--input``), or run ``--count`` random
       tests in a process pool.
    4. Print every reported error as JSON.
    """
    parser = ArgumentParser()
    parser.add_argument('--config', '-c', required=True)
    parser.add_argument('--log-dir', '-l')
    source = parser.add_mutually_exclusive_group()
    source.add_argument('--count', '-n', type=int, default=1)
    source.add_argument('--input', '-i')
    args = parser.parse_args()

    with open(args.config) as config_file:
        config = json.load(config_file)
    # Command-line values override same-named keys from the config file.
    config.update(vars(args))

    if 'filters' not in config:
        config['filters'] = []
    else:
        for rule in config['filters']:
            compile_rule(rule)

    if config['input'] is None:
        errors = []
        worker = functools.partial(do_test, config=config)
        with multiprocessing.Pool() as pool:
            # Progress line per finished test: index and number of errors.
            for idx, res in enumerate(pool.imap_unordered(worker, range(args.count))):
                print(f'#{idx}: {len(res)}')
                errors.extend(res)
    else:
        with open(config['input']) as input_file:
            config['input'] = input_file.readlines()
        errors = do_test(None, config)

    print(json.dumps(errors, indent=2))
if __name__ == '__main__':
    # Run the tester only when executed as a script, not when imported.
    main()
# vim: ts=4:sw=4:noet |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment