Skip to content

Instantly share code, notes, and snippets.

@yattom
Last active August 12, 2024 03:35
Show Gist options
  • Save yattom/3a3433352589214c89dc98ba5cd95722 to your computer and use it in GitHub Desktop.
Save yattom/3a3433352589214c89dc98ba5cd95722 to your computer and use it in GitHub Desktop.
An experimental framework for input forms and their validation
import csv
from typing import Any
from dataclasses import dataclass
from pathlib import Path
import pytest
class Validation:
    """Base validator that accepts every value and translates nothing.

    Subclasses override :meth:`validate` to reject values and/or to hand
    back a converted ("translated") value through the returned result.
    """

    @dataclass
    class Result:
        """Outcome of running one validation against one field value."""

        # Sentinel meaning "no translated value was produced". It carries no
        # annotation, so the dataclass machinery does not treat it as a field.
        NO_TRANSLATED_VALUE = object()

        is_valid: bool
        error_message: str
        translated_value: Any = NO_TRANSLATED_VALUE

        @staticmethod
        def success(translated_value=NO_TRANSLATED_VALUE):
            """Build a passing result, optionally carrying a translated value."""
            return Validation.Result(
                is_valid=True,
                error_message='',
                translated_value=translated_value,
            )

    def validate(self, value, field_name) -> 'Validation.Result':
        """Return a passing result; both arguments are ignored here."""
        passed = Validation.Result.success()
        return passed
class MandatoryValidation(Validation):
    """Validator that fails when no value was supplied (value is ``None``)."""

    def validate(self, value, field_name) -> 'Validation.Result':
        """Pass for any non-``None`` value; otherwise fail naming the field."""
        if value is not None:
            return Validation.Result.success()
        return Validation.Result(False, f'{field_name} is required')
class NumericValidation(Validation):
    """Validator that requires the value to be convertible with ``int()``."""

    def validate(self, value, field_name) -> 'Validation.Result':
        """Convert ``value`` to ``int`` and return it as the translated value.

        ``None`` (no input) passes without a translated value. Anything that
        ``int()`` rejects produces a failed result instead of an exception.
        """
        if value is None:
            return Validation.Result.success()
        try:
            number = int(value)
        except (ValueError, TypeError):
            # int() raises TypeError (not ValueError) for unsupported types
            # such as lists; both mean "not a number" for a validator.
            return Validation.Result(False, f'{field_name} must be a number')
        return Validation.Result.success(translated_value=number)
class FileValidation(Validation):
    """Validator that resolves a file name against a base directory.

    On success the translated value is ``Path(base_dir) / value``.
    """

    def __init__(self, base_dir, allowed_extensions):
        """Store the directory to look in and the accepted file suffixes.

        ``allowed_extensions`` is matched against ``Path.suffix``, so entries
        need their leading dot (e.g. ``'.zip'``).
        """
        self.base_dir = base_dir
        self.allowed_extensions = allowed_extensions

    def validate(self, value, field_name) -> 'Validation.Result':
        """Check that ``value`` names an existing, allowed file under base_dir.

        ``None`` (no input) passes without a translated value. Non-string
        values, missing files, files outside the base directory and files
        with a disallowed suffix all yield a failed result.
        """
        if value is None:
            return Validation.Result.success()
        if not isinstance(value, str):
            return Validation.Result(False, f'{field_name} must be a string')

        base_dir = Path(self.base_dir)
        translated_path = base_dir / Path(value)
        if not translated_path.is_file():
            return Validation.Result(False, f'{field_name} not found')

        # An absolute value or '..' segments would otherwise escape base_dir.
        # Report such paths as "not found" so the existence of files outside
        # the base directory is not revealed.
        if not translated_path.resolve().is_relative_to(base_dir.resolve()):
            return Validation.Result(False, f'{field_name} not found')

        # validate suffix against allowed extensions
        if translated_path.suffix not in self.allowed_extensions:
            return Validation.Result(False, f'{field_name} must end with {self.allowed_extensions}')
        return Validation.Result.success(translated_value=translated_path)
class CsvFileValidataion(Validation):
def validate(self, value, field_name) -> 'Validation.Result':
if value is None:
return Validation.Result.success()
match value:
case str():
return self.validate_file(value, field_name)
case Path():
return self.validate_file(value, field_name)
case dict():
return self.validate_unzipped_file(value, field_name)
case _:
return Validation.Result(False, f'{field_name} must be a string or Path or dict')
def validate_file(self, value, field_name):
file_path = Path(value)
if not file_path.is_file():
return Validation.Result(False, f'{field_name} not found')
if not file_path.suffix != '.csv':
return Validation.Result(False, f'{field_name} must end with .csv')
csv_contents = []
import csv
with open(file_path, 'r') as f:
reader = csv.reader(f)
for row in reader:
csv_contents.append(row)
return Validation.Result.success(translated_value=csv_contents)
def validate_unzipped_file(self, value, field_name):
csv_contents = []
for file, contents in value.items():
if not file.endswith('.csv'):
return Validation.Result(False, f'file {file} in {field_name} must end with .csv')
print(contents)
for row in csv.reader(contents.splitlines()):
csv_contents.append(row)
return Validation.Result.success(translated_value=csv_contents)
class ZipFileValidation(Validation):
    """Validator that unpacks a zip archive into memory.

    The translated value is a dict mapping each member file name to its
    contents decoded as UTF-8 text.
    """

    def validate(self, value, field_name) -> 'Validation.Result':
        """Read every file in the archive at ``value``.

        ``None`` (no input) passes without a translated value. A wrong
        suffix, an unreadable or corrupt archive, or a member that is not
        UTF-8 text yields a failed result rather than an exception.
        """
        import zipfile

        if value is None:
            return Validation.Result.success()
        file_path = Path(value)
        if file_path.suffix != '.zip':
            return Validation.Result(False, f'{field_name} must end with .zip')

        # unzip the file to retrieve its contents on memory
        file_contents = {}
        try:
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                for info in zip_ref.infolist():
                    # Directory entries carry no data; keeping them would add
                    # bogus empty "files" to the result.
                    if info.is_dir():
                        continue
                    file_contents[info.filename] = zip_ref.read(info).decode('utf-8')
        except (OSError, zipfile.BadZipFile):
            return Validation.Result(False, f'{field_name} is not a readable zip file')
        except UnicodeDecodeError:
            return Validation.Result(False, f'{field_name} must contain only UTF-8 text files')
        return Validation.Result.success(translated_value=file_contents)
class Field:
    """Declaration of a single form field: a target type plus validators."""

    def __init__(self, field_type, validations: list[Validation]):
        """Remember the type callable and the validators for this field."""
        self.validations = validations
        self.field_type = field_type
def field(field_type, validations: list[Validation] | None = None):
    """Create a :class:`Field`; omitting ``validations`` means no validators."""
    # A fresh list per call avoids sharing one mutable default between fields.
    return Field(field_type, [] if validations is None else validations)
class Form:
    """Base class for forms whose fields are declared as class attributes.

    Subclasses declare ``name = field(...)`` attributes. Assigning to such a
    name on an instance stores the raw input value; :meth:`validate` runs the
    type check and the validators, and :meth:`get` returns the value
    converted to the field's type.
    """

    # Message suffix used when a value cannot be converted to the given type.
    TYPE_ERROR_MESSAGES = {
        int: 'must be a number',
    }

    def __init__(self):
        self._fields: dict[str, Field] = {}
        self._field_values = {}
        self._populate_fields()

    def _populate_fields(self):
        """Collect the ``Field`` attributes declared on the class hierarchy.

        Base classes are visited first so that fields declared on a parent
        form are included and a subclass may override (or, by rebinding the
        name to a non-``Field``, remove) an inherited declaration.
        """
        for klass in reversed(type(self).__mro__):
            for attr_name, attr in vars(klass).items():
                if isinstance(attr, Field):
                    self._fields[attr_name] = attr
                else:
                    self._fields.pop(attr_name, None)

    def is_valid(self):
        """Return ``True`` when every validation result of the form passed."""
        _, results = self.validate()
        return all(result.is_valid for result in results)

    def validate(self) -> tuple[dict[str, Any], list[Validation.Result]]:
        """Validate every field.

        Returns a ``(translated_values, results)`` pair: the first maps each
        field name to its value after the validators' translations, the
        second lists, per field, one type-check result followed by one
        result per validator.
        """
        translated_values = {}
        results = []
        for field_name, field in self._fields.items():
            value = self._field_values.get(field_name)
            results.append(self._validate_field_type(field_name, value, field.field_type))
            field_results, translated_value = self._validate_field(field_name, value, field.validations)
            results.extend(field_results)
            translated_values[field_name] = translated_value
        return translated_values, results

    def _validate_field_type(self, field_name, value, field_type) -> Validation.Result:
        """Return a result telling whether ``value`` converts to ``field_type``."""
        if self._validate_type(value, field_type):
            return Validation.Result.success()
        reason = Form.TYPE_ERROR_MESSAGES.get(field_type, "has an invalid type")
        return Validation.Result(False, f'{field_name} {reason}')

    def _validate_field(self, field_name, value, validations) -> tuple[list[Validation.Result], Any]:
        """Run the validators in order, feeding each translated value onward.

        Returns the list of results and the value as it stands after the
        last successful translation.
        """
        results = []
        for validation in validations:
            result = validation.validate(value, field_name)
            results.append(result)
            # Identity, not equality: the sentinel is a unique object, and
            # translated values may define their own ``!=``.
            if result.is_valid and result.translated_value is not Validation.Result.NO_TRANSLATED_VALUE:
                value = result.translated_value
        return results, value

    def _validate_type(self, value, field_type):
        """Return whether ``field_type(value)`` succeeds; ``None`` always passes."""
        if value is None:
            return True
        try:
            field_type(value)
        except (ValueError, TypeError):
            return False
        return True

    def get(self, field_id: str | Field):
        """Return the stored value of a field converted to the field's type.

        ``field_id`` is either the field's name or the ``Field`` object
        itself. Returns ``None`` when no value has been set.

        Raises:
            ValueError: if the form has no such field.
        """
        if isinstance(field_id, Field):
            for name, field in self._fields.items():
                if field is field_id:
                    break
            else:
                raise ValueError(f'Field "{field_id}" not found')
        else:
            name = field_id
            field = self._fields.get(name)
            if field is None:
                raise ValueError(f'Field "{field_id}" not found')
        value = self._field_values.get(name)
        if value is None:
            # An unset field has no value to convert; str(None) would yield
            # the text 'None' and int(None) would raise.
            return None
        return field.field_type(value)

    def __setattr__(self, name, value):
        """Store assignments to declared fields as input values.

        Any other attribute is set normally. ``_fields`` is looked up through
        ``__dict__`` because it does not exist yet while ``__init__`` runs.
        """
        if name in self.__dict__.get('_fields', {}):
            self._field_values[name] = value
        else:
            super().__setattr__(name, value)
def test_空のForm():
    """A form that declares no fields is valid."""
    class MyForm(Form):
        pass

    empty_form = MyForm()
    outcome = empty_form.is_valid()
    assert outcome == True
# Shared MandatoryValidation instance used by the test forms below;
# the name 必須 is Japanese for "required".
必須 = MandatoryValidation()
class TestForm:
    """Value retrieval and validation on a form mixing str and int fields."""

    class MyForm(Form):
        text_field_optional = field(str)
        text_field = field(str, [必須])
        int_field = field(int)

    def test_値取得(self):
        """Values are retrieved by field name, converted to the field type."""
        form = self.MyForm()
        form.text_field_optional = 'value1'
        form.text_field = 'value2'
        form.int_field = '123'

        expected_by_name = {
            'text_field_optional': 'value1',
            'text_field': 'value2',
            'int_field': 123,
        }
        for name, expected in expected_by_name.items():
            assert form.get(name) == expected

    def test_値取得_安全な書き方(self):
        """Values are retrieved by passing the Field object instead of a name."""
        form = self.MyForm()
        form.text_field_optional = 'value1'
        form.text_field = 'value2'
        form.int_field = '123'

        # Reading a field attribute yields the class-level Field declaration.
        optional_text = form.get(form.text_field_optional)
        required_text = form.get(form.text_field)
        number = form.get(form.int_field)

        assert optional_text == 'value1'
        assert required_text == 'value2'
        assert number == 123

    def test_バリデーション(self):
        """A missing required field and a non-numeric int make the form invalid."""
        form = self.MyForm()
        form.text_field_optional = 'value1'
        form.int_field = 'abc'

        assert form.is_valid() == False

    def test_バリデーション結果取得(self):
        """validate() reports one failed result per broken rule."""
        form = self.MyForm()
        form.text_field_optional = 'value1'
        form.int_field = 'abc'

        results: list[Validation.Result] = form.validate()[1]

        assert Validation.Result(False, 'int_field must be a number') in results
        assert Validation.Result(False, 'text_field is required') in results
class TestForm自由入力:
    """An optional text field is valid whether or not it has a value."""

    class MyForm(Form):
        field1 = field(str)

    def test_未入力(self):
        """Leaving the field unset is valid."""
        untouched = self.MyForm()
        assert untouched.is_valid()

    def test_入力済み(self):
        """Setting the field is valid."""
        filled = self.MyForm()
        filled.field1 = 'value'
        assert filled.is_valid()
class TestForm入力必須:
    """A required text field must have a value for the form to be valid."""

    class MyForm(Form):
        field1 = field(str, [必須])

    def test_未入力(self):
        """Leaving the required field unset is invalid."""
        untouched = self.MyForm()
        assert not untouched.is_valid()

    def test_入力済み(self):
        """Setting the required field makes the form valid."""
        filled = self.MyForm()
        filled.field1 = 'value'
        assert filled.is_valid()

    def test_値取得(self):
        """The stored value is returned when looked up via the Field object."""
        filled = self.MyForm()
        filled.field1 = 'value'
        retrieved = filled.get(filled.field1)
        assert retrieved == 'value'
class TestForm数字のみ:
    """An int field accepts numeric text and rejects anything else."""

    class MyForm(Form):
        int1 = field(int)

    def test_未入力(self):
        """Leaving the int field unset is valid."""
        untouched = self.MyForm()
        assert untouched.is_valid()

    def test_数字以外(self):
        """Non-numeric text is invalid."""
        form = self.MyForm()
        form.int1 = 'value'
        assert not form.is_valid()

    def test_数字(self):
        """Numeric text is valid."""
        form = self.MyForm()
        form.int1 = '123'
        assert form.is_valid()

    def test_値取得(self):
        """Numeric text is returned converted to int."""
        form = self.MyForm()
        form.int1 = '123'
        retrieved = form.get(form.int1)
        assert retrieved == 123
@pytest.fixture
def temp_base_dir(tmp_path):
    """Create and return a ``data`` directory under pytest's ``tmp_path``."""
    data_dir = tmp_path / 'data'
    data_dir.mkdir()
    return data_dir
class TestFormファイル:
    """End-to-end check of the file -> zip -> CSV validator chain."""

    def test_値取得(self, temp_base_dir):
        """A zipped CSV named in the form is translated into its rows."""
        class MyForm(Form):
            zipped_csv_file = field(
                str,
                [
                    必須,
                    FileValidation(base_dir=temp_base_dir, allowed_extensions=['.zip']),
                    ZipFileValidation(),
                    CsvFileValidataion(),
                ],
            )

        create_temp_zipped_csv_file(temp_base_dir, 'sample.zip', [['a', 'b'], [1, 2]])

        form = MyForm()
        form.zipped_csv_file = 'sample.zip'
        translated_values, results = form.validate()

        # Every step must pass; on failure the results show which one did not.
        assert all(result.is_valid for result in results), results
        # csv.reader yields strings, so the numeric cells come back as text.
        assert translated_values['zipped_csv_file'] == [['a', 'b'], ['1', '2']]
def create_temp_zipped_csv_file(base_dir: Path, file_name: str, contents: list[list[object]]) -> Path:
    """Write a zip archive holding one CSV file and return the archive's path.

    The archive is created at ``base_dir / file_name`` and contains a single
    member named ``sample.csv`` whose rows are ``contents`` as written by
    ``csv.writer`` (UTF-8 encoded).
    """
    import csv
    import zipfile
    from io import StringIO

    csv_buffer = StringIO()
    csv.writer(csv_buffer).writerows(contents)

    file_path = base_dir / file_name
    with zipfile.ZipFile(file_path, 'w') as zip_ref:
        # writestr encodes str data as UTF-8 before storing it.
        zip_ref.writestr('sample.csv', csv_buffer.getvalue())
    return file_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment