Skip to content

Instantly share code, notes, and snippets.

@todesking
Created April 28, 2010 08:25
Show Gist options
Save todesking/381869 to your computer and use it in GitHub Desktop.
Contract for Java
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Map;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Test;
import com.google.common.collect.Maps;
/**
 * JUnit tests for {@link Contract}, a small design-by-contract helper that wraps an
 * interface implementation in a dynamic proxy and runs precondition/postcondition
 * checks around each call.
 */
public class ContractTest {
    /** Minimal id-to-string storage used as the example contract target. */
    public static interface Storage {
        public boolean exists(int id);

        public void insert(int id, String value);

        public void delete(int id);

        public String get(int id);

        /**
         * Contract for {@link Storage}.
         *
         * <p>{@code before_X} runs before {@code X} with the same arguments;
         * {@code after_X} runs after {@code X} with the return value ({@code null}
         * for {@code void} methods) followed by the same arguments.
         */
        public static class WithContract extends Contract<Storage> {
            protected void before_insert(int id, String value) {
                // An id that already exists cannot be inserted.
                contract("既に存在するIDはinsertできない", !target.exists(id));
                // The value must be non-null.
                contract("valueは非null", value != null);
            }

            protected void after_insert(Void result, int id, String value) {
                // After insert, the id exists.
                contract("insert後はidが存在する", target.exists(id));
            }

            protected void before_get(int id) {
                // An id that does not exist cannot be fetched.
                contract("存在しないidはgetできない", target.exists(id));
            }

            protected void after_get(String result, int id) {
                // get never returns null.
                contract("getはnullを返さない", result != null);
            }

            protected void before_delete(int id) {
                // An id that does not exist cannot be deleted.
                contract("存在しないidは削除できない", target.exists(id));
            }

            protected void after_delete(Void result, int id) {
                // After delete, the id no longer exists.
                contract("削除されたidは存在しない", !target.exists(id));
            }

            private WithContract(Storage target) {
                super(Storage.class, target);
            }

            /**
             * Wraps {@code target} in a proxy that enforces this contract.
             *
             * @param target the implementation to wrap; must not be null
             * @return a {@link Storage} proxy delegating to {@code target}
             */
            public static Storage from(Storage target) {
                return new WithContract(target).createProxy();
            }
        }
    }

    /** {@link Storage} backed by an in-memory hash map. */
    public static class StorageImpl implements Storage {
        private final Map<Integer, String> storage = Maps.newHashMap();

        @Override
        public void delete(int id) {
            storage.remove(id);
        }

        @Override
        public boolean exists(int id) {
            return storage.containsKey(id);
        }

        @Override
        public String get(int id) {
            return storage.get(id);
        }

        @Override
        public void insert(int id, String value) {
            storage.put(id, value);
        }
    }

    /** A correct implementation passes every contract check. */
    @Test
    public void test_defaultAccessorImpl() throws Exception {
        final Storage target = Storage.WithContract.from(new StorageImpl());
        assertThat(target.exists(100), is(false));
        target.insert(100, "item-100");
        assertThat(target.exists(100), is(true));
        assertThat(target.get(100), is("item-100"));
        target.delete(100);
        assertThat(target.exists(100), is(false));
    }

    /**
     * Each contract fires on a misbehaving (mocked) target. Every case asserts the
     * error message so the test pins down WHICH contract was violated.
     */
    @Test
    public void test_wrongBehavior() throws Exception {
        final Storage mocked = mock(Storage.class);
        when(mocked.exists(100)).thenReturn(false);
        when(mocked.get(100)).thenReturn("item-100");
        assertThat(mocked.get(100), is("item-100"));
        final Storage mockWithContract = Storage.WithContract.from(mocked);

        // All contracts hold: the call reaches the mock.
        when(mocked.exists(100)).thenReturn(true);
        assertThat(mockWithContract.get(100), is("item-100"));

        // before_get: the id does not exist.
        when(mocked.exists(100)).thenReturn(false);
        try {
            mockWithContract.get(100);
            fail();
        } catch (ContractError e) {
            assertThat(e.getMessage(), is("存在しないidはgetできない"));
        }

        // before_insert: exists(10) is false, so only the null-value check fails.
        when(mocked.exists(10)).thenReturn(false, true);
        try {
            mockWithContract.insert(10, null);
            fail();
        } catch (ContractError e) {
            assertThat(e.getMessage(), is("valueは非null"));
        }

        // after_insert: the mock never reports the id as existing.
        when(mocked.exists(10)).thenReturn(false, false);
        try {
            mockWithContract.insert(10, "hoge");
            fail();
        } catch (ContractError e) {
            assertThat(e.getMessage(), is("insert後はidが存在する"));
        }

        // after_get: the id exists but get returns null. exists(10) must be
        // re-stubbed to true; otherwise before_get fails first and after_get is
        // never reached.
        when(mocked.exists(10)).thenReturn(true);
        when(mocked.get(10)).thenReturn(null);
        try {
            mockWithContract.get(10);
            fail();
        } catch (ContractError e) {
            assertThat(e.getMessage(), is("getはnullを返さない"));
        }

        // before_delete: the id does not exist.
        when(mocked.exists(20)).thenReturn(false);
        try {
            mockWithContract.delete(20);
            fail();
        } catch (ContractError e) {
            assertThat(e.getMessage(), is("存在しないidは削除できない"));
        }

        // after_delete: the id still exists after delete.
        when(mocked.exists(20)).thenReturn(true);
        try {
            mockWithContract.delete(20);
            fail();
        } catch (ContractError e) {
            assertThat(e.getMessage(), is("削除されたidは存在しない"));
        }
    }

    /**
     * An exception thrown by the wrapped target reaches the caller as-is, not
     * wrapped in an UndeclaredThrowableException by the proxy.
     */
    @Test
    public void test_targetExceptionPropagates() throws Exception {
        final Storage mocked = mock(Storage.class);
        final Storage mockWithContract = Storage.WithContract.from(mocked);
        when(mocked.exists(30)).thenReturn(true);
        when(mocked.get(30)).thenThrow(new IllegalStateException("boom"));
        try {
            mockWithContract.get(30);
            fail();
        } catch (IllegalStateException e) {
            assertThat(e.getMessage(), is("boom"));
        }
    }

    /**
     * Thrown when a contract (precondition or postcondition) is violated.
     */
    public static class ContractError extends Error {
        private static final long serialVersionUID = 1L;

        public ContractError(String msg) {
            super(msg);
        }
    }

    /**
     * Base class for enforcing contracts on an interface via a dynamic proxy.
     *
     * <p>Subclasses declare contract methods by naming convention:
     * <ul>
     * <li>{@code before_X(args...)} runs before the target's {@code X(args...)};</li>
     * <li>{@code after_X(result, args...)} runs after it, receiving the return value
     *     ({@code null} for {@code void}) followed by the original arguments.</li>
     * </ul>
     * Contract methods are matched by name and parameter count. The search covers the
     * concrete subclass and its superclasses below {@code Contract}, and contract
     * methods of any visibility are accepted. A contract method whose parameter count
     * does not match is ignored. Inside a contract, call
     * {@link #contract(String, boolean)} to assert a condition.
     *
     * @param <T> the interface type being wrapped
     */
    static abstract class Contract<T> {
        /** The wrapped implementation; contract methods may query it. */
        protected final T target;
        /** The interface the proxy implements. */
        protected final Class<T> targetClass;

        /**
         * Asserts a contract condition.
         *
         * @param description message of the error raised when {@code cond} is false
         * @param cond the condition that must hold
         * @throws ContractError if {@code cond} is false
         */
        protected void contract(String description, boolean cond) {
            if (!cond) {
                throw new ContractError(description);
            }
        }

        /**
         * Asserts a contract condition with a generic message.
         *
         * @param cond the condition that must hold
         * @throws ContractError if {@code cond} is false
         */
        protected void contract(boolean cond) {
            contract("contract error", cond);
        }

        /**
         * @param targetClass the interface the proxy will implement; must not be null
         * @param target the implementation to wrap; must not be null
         * @throws NullPointerException if either argument is null
         */
        public <U extends T> Contract(Class<T> targetClass, U target) {
            if (targetClass == null) {
                throw new NullPointerException("targetClass");
            }
            if (target == null) {
                throw new NullPointerException("target");
            }
            this.target = target;
            this.targetClass = targetClass;
        }

        /**
         * Creates a proxy that implements {@code targetClass} and delegates to
         * {@code target}, running the matching contract methods around each call.
         * Exceptions thrown by the target or by a contract (including
         * {@link ContractError}) propagate to the caller unchanged.
         *
         * @return the contract-enforcing proxy
         * @throws IllegalArgumentException if {@code targetClass} is not an interface
         *     (reported by {@link Proxy#newProxyInstance})
         */
        public T createProxy() {
            final InvocationHandler handler = new InvocationHandler() {
                @Override
                public Object invoke(Object proxy, Method method, Object[] args)
                        throws Throwable {
                    final Method beforeContract = findBeforeContractFor(method);
                    if (beforeContract != null) {
                        invokeUnwrapped(beforeContract, Contract.this, args);
                    }
                    final Object result = invokeUnwrapped(method, target, args);
                    final Method afterContract = findAfterContractFor(method);
                    if (afterContract != null) {
                        // The postcondition receives the result first, then the arguments.
                        invokeUnwrapped(afterContract, Contract.this,
                                ArrayUtils.addAll(new Object[] { result }, args));
                    }
                    return result;
                }
            };
            // Safe: the proxy implements exactly targetClass, which is a Class<T>.
            @SuppressWarnings("unchecked")
            final T withContract = (T) Proxy.newProxyInstance(
                    targetClass.getClassLoader(),
                    new Class<?>[] { targetClass },
                    handler);
            return withContract;
        }

        /** Finds {@code before_<name>} taking the same number of parameters as {@code method}. */
        private Method findBeforeContractFor(Method method) {
            return findContractMethod("before_" + method.getName(),
                    method.getParameterTypes().length);
        }

        /** Finds {@code after_<name>} taking the result plus {@code method}'s parameters. */
        private Method findAfterContractFor(Method method) {
            return findContractMethod("after_" + method.getName(),
                    method.getParameterTypes().length + 1);
        }

        /**
         * Searches this contract's class, then its superclasses up to (but excluding)
         * {@code Contract}, for a method with the given name and parameter count.
         *
         * @return the accessible contract method, or {@code null} if none matches
         */
        private Method findContractMethod(String name, int arity) {
            for (Class<?> c = this.getClass(); c != null && c != Contract.class;
                    c = c.getSuperclass()) {
                for (Method m : c.getDeclaredMethods()) {
                    if (m.getName().equals(name) && m.getParameterTypes().length == arity) {
                        // Contract methods are typically protected (or private) in a
                        // subclass that may live in another package; without this,
                        // Method.invoke from this base class throws IllegalAccessException.
                        m.setAccessible(true);
                        return m;
                    }
                }
            }
            return null;
        }

        /**
         * Invokes {@code m} reflectively and rethrows whatever the callee itself threw,
         * instead of the InvocationTargetException wrapper. That wrapper is a checked
         * exception, which the proxy would otherwise turn into an
         * UndeclaredThrowableException.
         */
        private static Object invokeUnwrapped(Method m, Object receiver, Object[] args)
                throws Throwable {
            try {
                return m.invoke(receiver, args);
            } catch (InvocationTargetException e) {
                throw e.getTargetException();
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment