Last active
August 12, 2024 03:35
-
-
Save yattom/3a3433352589214c89dc98ba5cd95722 to your computer and use it in GitHub Desktop.
An experimental framework for input forms and validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import io
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pytest
class Validation:
    """Base class for one validation rule applied to one field value.

    Subclasses override :meth:`validate`. The base implementation accepts
    every value and does not translate it.
    """

    @dataclass
    class Result:
        """Outcome of a single validation."""

        # Sentinel meaning "this validation produced no translated value".
        # Not annotated, so it is a plain class attribute, not a dataclass field.
        NO_TRANSLATED_VALUE = object()

        is_valid: bool
        error_message: str
        translated_value: Any = NO_TRANSLATED_VALUE

        @staticmethod
        def success(translated_value=NO_TRANSLATED_VALUE):
            """Build a passing result, optionally carrying a translated value."""
            return Validation.Result(
                is_valid=True,
                error_message='',
                translated_value=translated_value,
            )

    def validate(self, value, field_name) -> 'Validation.Result':
        """Accept ``value`` unconditionally; subclasses narrow this."""
        accepted = Validation.Result.success()
        return accepted
class MandatoryValidation(Validation):
    """Fail when the field has no value (``None``); any other value passes."""

    def validate(self, value, field_name) -> 'Validation.Result':
        if value is not None:
            return Validation.Result.success()
        return Validation.Result(False, f'{field_name} is required')
class NumericValidation(Validation):
    """Require a value convertible with ``int()`` and translate it to ``int``.

    ``None`` (no input) passes, so this rule can be combined with
    ``MandatoryValidation`` for required numeric fields.
    """

    def validate(self, value, field_name) -> 'Validation.Result':
        if value is None:
            return Validation.Result.success()
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            # ValueError: unparsable text such as 'abc'.
            # TypeError: unsupported types such as lists or dicts.
            # OverflowError: float('inf').
            # All three mean "not a number" rather than a crash.
            return Validation.Result(False, f'{field_name} must be a number')
        return Validation.Result.success(translated_value=number)
class FileValidation(Validation):
    """Require the value to name an existing file under ``base_dir``.

    On success the translated value is ``Path(base_dir) / value``.
    """

    def __init__(self, base_dir, allowed_extensions):
        # base_dir: directory that relative values are resolved against.
        # allowed_extensions: suffixes including the dot, compared against
        # Path.suffix (e.g. ['.zip']).
        self.base_dir = base_dir
        self.allowed_extensions = allowed_extensions

    def validate(self, value, field_name) -> 'Validation.Result':
        if value is None:
            return Validation.Result.success()
        if not isinstance(value, (str, Path)):
            return Validation.Result(False, f'{field_name} must be a string')
        base_dir = Path(self.base_dir)
        translated_path = base_dir / Path(value)
        # Joining an absolute path, or one containing '..', can point outside
        # base_dir. Form values may be user-supplied, so confine the resolved
        # path (symlinks and '..' expanded) to the resolved base directory.
        if not translated_path.resolve().is_relative_to(base_dir.resolve()):
            return Validation.Result(False, f'{field_name} must be inside the base directory')
        if not translated_path.is_file():
            return Validation.Result(False, f'{field_name} not found')
        if translated_path.suffix not in self.allowed_extensions:
            return Validation.Result(False, f'{field_name} must end with {self.allowed_extensions}')
        return Validation.Result.success(translated_value=translated_path)
class CsvFileValidataion(Validation): | |
def validate(self, value, field_name) -> 'Validation.Result': | |
if value is None: | |
return Validation.Result.success() | |
match value: | |
case str(): | |
return self.validate_file(value, field_name) | |
case Path(): | |
return self.validate_file(value, field_name) | |
case dict(): | |
return self.validate_unzipped_file(value, field_name) | |
case _: | |
return Validation.Result(False, f'{field_name} must be a string or Path or dict') | |
def validate_file(self, value, field_name): | |
file_path = Path(value) | |
if not file_path.is_file(): | |
return Validation.Result(False, f'{field_name} not found') | |
if not file_path.suffix != '.csv': | |
return Validation.Result(False, f'{field_name} must end with .csv') | |
csv_contents = [] | |
import csv | |
with open(file_path, 'r') as f: | |
reader = csv.reader(f) | |
for row in reader: | |
csv_contents.append(row) | |
return Validation.Result.success(translated_value=csv_contents) | |
def validate_unzipped_file(self, value, field_name): | |
csv_contents = [] | |
for file, contents in value.items(): | |
if not file.endswith('.csv'): | |
return Validation.Result(False, f'file {file} in {field_name} must end with .csv') | |
print(contents) | |
for row in csv.reader(contents.splitlines()): | |
csv_contents.append(row) | |
return Validation.Result.success(translated_value=csv_contents) | |
class ZipFileValidation(Validation):
    """Read a ``.zip`` archive into memory as a dict of member name -> text.

    Every file member is decoded as UTF-8; directory entries are skipped.
    Unreadable input becomes a failed Result instead of an exception.
    """

    def validate(self, value, field_name) -> 'Validation.Result':
        if value is None:
            return Validation.Result.success()
        # A preceding step that failed may hand over its raw (untranslated)
        # value, so do not assume a valid path here.
        if not isinstance(value, (str, Path)):
            return Validation.Result(False, f'{field_name} must be a string or Path')
        file_path = Path(value)
        if file_path.suffix != '.zip':
            return Validation.Result(False, f'{field_name} must end with .zip')
        if not file_path.is_file():
            return Validation.Result(False, f'{field_name} not found')
        file_contents = {}
        # NOTE(review): members are decompressed fully into memory with no size
        # limit; if archives can come from untrusted users, consider capping
        # ZipInfo.file_size to guard against zip bombs.
        try:
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                for info in zip_ref.infolist():
                    if info.is_dir():
                        continue
                    with zip_ref.open(info) as f:
                        file_contents[info.filename] = f.read().decode('utf-8')
        except zipfile.BadZipFile:
            return Validation.Result(False, f'{field_name} is not a valid zip file')
        except UnicodeDecodeError:
            return Validation.Result(False, f'{field_name} must contain only UTF-8 text files')
        except OSError:
            return Validation.Result(False, f'{field_name} could not be read')
        return Validation.Result.success(translated_value=file_contents)
class Field:
    """Declaration of one form field: its type and its validation chain."""

    def __init__(self, field_type, validations: list[Validation]):
        # field_type: callable the Form uses to type-check and convert raw values.
        # validations: rules the Form applies in order.
        self.field_type, self.validations = field_type, validations
def field(field_type, validations: list[Validation] | None = None):
    """Declare a form field; omit ``validations`` for a field with no rules.

    A fresh empty list is created per call when ``validations`` is None.
    """
    return Field(field_type, [] if validations is None else validations)
class Form:
    """Base class for declarative input forms.

    Subclasses declare fields as class attributes created with ``field()``.
    Values are assigned with attribute syntax (``form.name = value``).
    ``__setattr__`` stores them per instance in ``_field_values`` instead of
    replacing the class-level ``Field``, so ``form.name`` still evaluates to
    the ``Field`` object and can be passed to :meth:`get`.
    """

    # Error-message suffix per field type, used when the type check fails.
    TYPE_ERROR_MESSAGES = {
        int: 'must be a number',
    }

    def __init__(self):
        self._fields: dict[str, Field] = {}
        self._field_values = {}
        self._populate_fields()

    def _populate_fields(self):
        """Register every ``Field`` declared on this form's class or its bases.

        Bases are visited first, so inherited fields are included and a
        subclass can redefine them.
        """
        for klass in reversed(type(self).__mro__):
            for attr_name, attr in vars(klass).items():
                if isinstance(attr, Field):
                    self._fields[attr_name] = attr
                else:
                    # A subclass attribute that is not a Field hides an inherited one.
                    self._fields.pop(attr_name, None)

    def is_valid(self):
        """Return True if the type check and all validations of every field pass."""
        _, results = self.validate()
        return all(result.is_valid for result in results)

    def validate(self) -> tuple[dict[str, Any], list[Validation.Result]]:
        """Type-check every field and run its validation chain.

        Returns:
            (translated_values, results). translated_values maps each field
            name to its value after the last translating validation (the raw
            value if none translated it). results holds, per field, the
            type-check result followed by the validation results up to and
            including the first failure.
        """
        translated_values = {}
        results = []
        for field_name, form_field in self._fields.items():
            value = self._field_values.get(field_name)
            results.append(self._validate_field_type(field_name, value, form_field.field_type))
            field_results, translated_value = self._validate_field(field_name, value, form_field.validations)
            results.extend(field_results)
            translated_values[field_name] = translated_value
        return translated_values, results

    def _validate_field_type(self, field_name, value, field_type) -> Validation.Result:
        """Return a failed Result if ``value`` cannot be converted by ``field_type``."""
        if not self._validate_type(value, field_type):
            return Validation.Result(
                False,
                f'{field_name} {Form.TYPE_ERROR_MESSAGES.get(field_type, "has an invalid type")}')
        return Validation.Result.success()

    def _validate_field(self, field_name, value, validations) -> tuple[list[Validation.Result], Any]:
        """Run ``validations`` in order, piping each translated value into the next."""
        results = []
        for validation in validations:
            result = validation.validate(value, field_name)
            results.append(result)
            if not result.is_valid:
                # Later steps expect the previous step's translated value
                # (e.g. path -> zip contents -> CSV rows). Feeding them the
                # value of a failed step is meaningless and can raise, so stop.
                break
            if result.translated_value is not Validation.Result.NO_TRANSLATED_VALUE:
                value = result.translated_value
        return results, value

    def _validate_type(self, value, field_type):
        """Return True if ``value`` is None or ``field_type(value)`` succeeds."""
        if value is None:
            return True
        try:
            field_type(value)
        except (ValueError, TypeError, OverflowError):
            return False
        return True

    def get(self, field_id: str | Field):
        """Return a field's stored value converted with its ``field_type``.

        Args:
            field_id: the field's name, or the Field object itself (e.g.
                ``form.get(form.name)``).

        Returns:
            None if the field has no value, else ``field_type(value)``.
            Conversion errors such as ValueError from ``int('abc')`` propagate.

        Raises:
            ValueError: if the field is not declared on this form.
        """
        if isinstance(field_id, Field):
            name = next((n for n, f in self._fields.items() if f is field_id), None)
        else:
            name = field_id if field_id in self._fields else None
        if name is None:
            raise ValueError(f'Field "{field_id}" not found')
        value = self._field_values.get(name)
        if value is None:
            # Converting None would give 'None' for str or raise for int.
            return None
        return self._fields[name].field_type(value)

    def __setattr__(self, name, value):
        # Assignments to declared field names are stored per instance. All
        # other names (including _fields/_field_values while __init__ runs)
        # become normal attributes.
        if name in self.__dict__.get('_fields', {}):
            self._field_values[name] = value
        else:
            super().__setattr__(name, value)
def test_空のForm():
    """A form that declares no fields is trivially valid."""
    class MyForm(Form):
        """No fields."""

    assert MyForm().is_valid() is True
# Shared "required" rule (必須 means "required"), reused by the test forms below.
必須: MandatoryValidation = MandatoryValidation()
class TestForm:
    """Value access and validation on a form with optional, required and int fields."""

    class MyForm(Form):
        text_field_optional = field(str)
        text_field = field(str, [必須])
        int_field = field(int)

    def _make_form(self, **values):
        """Create a MyForm and assign the given field values in order."""
        form = self.MyForm()
        for name, value in values.items():
            setattr(form, name, value)
        return form

    def test_値取得(self):
        form = self._make_form(text_field_optional='value1', text_field='value2', int_field='123')
        assert form.get('text_field_optional') == 'value1'
        assert form.get('text_field') == 'value2'
        assert form.get('int_field') == 123

    def test_値取得_安全な書き方(self):
        form = self._make_form(text_field_optional='value1', text_field='value2', int_field='123')
        assert form.get(form.text_field_optional) == 'value1'
        assert form.get(form.text_field) == 'value2'
        assert form.get(form.int_field) == 123

    def test_バリデーション(self):
        form = self._make_form(text_field_optional='value1', int_field='abc')
        assert form.is_valid() == False

    def test_バリデーション結果取得(self):
        form = self._make_form(text_field_optional='value1', int_field='abc')
        _, results = form.validate()
        assert Validation.Result(False, 'int_field must be a number') in results
        assert Validation.Result(False, 'text_field is required') in results
class TestForm自由入力:
    """A field with no validations is valid both when empty and when filled."""

    class MyForm(Form):
        field1 = field(str)

    def test_未入力(self):
        assert self.MyForm().is_valid()

    def test_入力済み(self):
        filled = self.MyForm()
        filled.field1 = 'value'
        assert filled.is_valid()
class TestForm入力必須:
    """A field with the 必須 (required) rule is invalid until a value is set."""

    class MyForm(Form):
        field1 = field(str, [必須])

    def _filled_form(self):
        """Return a MyForm whose field1 is set to 'value'."""
        filled = self.MyForm()
        filled.field1 = 'value'
        return filled

    def test_未入力(self):
        assert not self.MyForm().is_valid()

    def test_入力済み(self):
        assert self._filled_form().is_valid()

    def test_値取得(self):
        form = self._filled_form()
        assert form.get(form.field1) == 'value'
class TestForm数字のみ:
    """An int-typed field: empty and numeric input pass, other text fails."""

    # Shared by every test. The previous per-test local copies were identical
    # and only shadowed this definition.
    class MyForm(Form):
        int1 = field(int)

    def test_未入力(self):
        form = self.MyForm()
        assert form.is_valid()

    def test_数字以外(self):
        form = self.MyForm()
        form.int1 = 'value'
        assert not form.is_valid()

    def test_数字(self):
        form = self.MyForm()
        form.int1 = '123'
        assert form.is_valid()

    def test_値取得(self):
        form = self.MyForm()
        form.int1 = '123'
        assert form.get(form.int1) == 123
@pytest.fixture
def temp_base_dir(tmp_path):
    """Provide a freshly created, empty ``data`` directory inside ``tmp_path``."""
    data_dir = tmp_path.joinpath('data')
    data_dir.mkdir()
    return data_dir
class TestFormファイル:
    """End to end: a zip upload is located, unzipped and parsed as CSV."""

    def test_値取得(self, temp_base_dir):
        class MyForm(Form):
            zipped_csv_file = field(str, [
                必須,
                FileValidation(base_dir=temp_base_dir, allowed_extensions=['.zip']),
                ZipFileValidation(),
                CsvFileValidataion(),
            ])

        create_temp_zipped_csv_file(temp_base_dir, 'sample.zip', [['a', 'b'], [1, 2]])
        form = MyForm()
        form.zipped_csv_file = 'sample.zip'
        translated_values, results = form.validate()
        # Every step of the chain must pass; show the results if one does not.
        assert all(r.is_valid for r in results), results
        # csv.writer stringifies the ints, so they come back as str.
        assert translated_values['zipped_csv_file'] == [['a', 'b'], ['1', '2']]
def create_temp_zipped_csv_file(base_dir: Path, file_name: str, contents: list[list[Any]],
                                member_name: str = 'sample.csv') -> Path:
    """Write the zip archive ``base_dir / file_name`` containing one CSV member.

    Args:
        base_dir: existing directory in which to create the archive.
        file_name: name of the archive file.
        contents: rows to write; cells are converted to str by csv.writer.
        member_name: name of the CSV entry inside the archive.

    Returns:
        Path of the created archive.
    """
    csv_buffer = io.StringIO()
    csv.writer(csv_buffer).writerows(contents)
    file_path = base_dir / file_name
    with zipfile.ZipFile(file_path, 'w') as zip_ref:
        zip_ref.writestr(member_name, csv_buffer.getvalue().encode('utf-8'))
    return file_path
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment