Created January 19, 2022 21:33
Save trzy/8eb1452665248b91064d0efbfc35dde9 to your computer and use it in GitHub Desktop.
PyTorch Memory Leak and Redundant .detach()
This file has been truncated, but you can view the full file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
import numpy as np | |
import random | |
import torch as t | |
from torch import nn | |
from torchvision.ops import nms | |
from torch.nn import functional as F | |
from torchvision.ops import RoIPool | |
from torchvision.models import vgg16 | |
def create_optimizer(model, lr = 1e-3, momentum = 0.9, weight_decay = 0):
    """
    Creates an SGD optimizer over every trainable parameter of a model.

    Each trainable parameter gets its own parameter group. Parameters whose
    name contains "weight" use `weight_decay`; all others (e.g., biases) use no
    weight decay. Every trainable parameter must be included, otherwise it is
    silently never updated.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are to be optimized.
    lr : float
        Learning rate.
    momentum : float
        SGD momentum.
    weight_decay : float
        Weight decay applied to parameters whose name contains "weight".

    Returns
    -------
    torch.optim.SGD
        Optimizer.
    """
    params = []
    for key, value in model.named_parameters():
        # Frozen parameters are left out of the optimizer entirely
        if not value.requires_grad:
            continue
        decay = weight_decay if "weight" in key else 0
        params.append({ "params": [ value ], "weight_decay": decay })
    return t.optim.SGD(params, lr = lr, momentum = momentum)
def run():
    """
    Trains a FasterRCNNModel on one fixed sample in an endless loop.

    The sample is read from module-level arrays (__image_data, __anchor_map,
    __anchor_valid_map, __gt_rpn_map, __gt_rpn_object_indices,
    __gt_rpn_background_indices, __gt_boxes) that are not visible in this part
    of the file -- presumably defined in its truncated remainder. The loop has
    no exit condition; it only ends if train_step raises.
    """
    model = FasterRCNNModel(num_classes = 21, allow_edge_proposals = True, dropout_probability = 0)
    model = model.cuda()
    optimizer = create_optimizer(model)

    # Repeat the same training step forever; input tensors are rebuilt from
    # the NumPy arrays on every iteration.
    while True:
        step_loss = model.train_step(
            optimizer = optimizer,
            image_data = t.from_numpy(__image_data).unsqueeze(dim = 0).float().cuda(),
            anchor_map = __anchor_map.astype(float),
            anchor_valid_map = __anchor_valid_map.astype(float),
            gt_rpn_map = t.from_numpy(__gt_rpn_map).unsqueeze(dim = 0).float().cuda(),
            gt_rpn_object_indices = [ __gt_rpn_object_indices ],
            gt_rpn_background_indices = [ __gt_rpn_background_indices ],
            gt_boxes = [ __gt_boxes ]
        )
class FasterRCNNModel(nn.Module):
    """
    Faster R-CNN model made of three stages:

    1. a feature extractor,
    2. a region proposal network (RPN),
    3. a detector network.

    Training currently supports a batch size of 1 only.
    """

    @dataclass
    class Loss:
        # Individual loss terms and their sum, as Python floats.
        rpn_class: float
        rpn_regression: float
        detector_class: float
        detector_regression: float
        total: float

    def __init__(self, num_classes, rpn_minibatch_size = 256, proposal_batch_size = 128, allow_edge_proposals = True, dropout_probability = 0):
        """
        Parameters
        ----------
        num_classes : int
            Number of classes, including the background class 0.
        rpn_minibatch_size : int
            Number of anchors sampled per image for RPN training (must be even).
        proposal_batch_size : int
            Maximum number of proposals sampled per image for detector training.
        allow_edge_proposals : bool
            Passed to the RPN; whether proposals may come from anchors marked
            invalid in the anchor valid map.
        dropout_probability : float
            Dropout probability used in the detector's fully-connected layers.
        """
        super().__init__()

        # Constants
        self._num_classes = num_classes
        self._rpn_minibatch_size = rpn_minibatch_size
        self._proposal_batch_size = proposal_batch_size
        self._detector_box_delta_means = [ 0, 0, 0, 0 ]
        self._detector_box_delta_stds = [ 0.1, 0.1, 0.2, 0.2 ]

        # Network stages
        self._stage1_feature_extractor = FeatureExtractor()
        self._stage2_region_proposal_network = RegionProposalNetwork(allow_edge_proposals = allow_edge_proposals)
        self._stage3_detector_network = DetectorNetwork(num_classes = num_classes, dropout_probability = dropout_probability)

    def train_step(self, optimizer, image_data, anchor_map, anchor_valid_map, gt_rpn_map, gt_rpn_object_indices, gt_rpn_background_indices, gt_boxes):
        """
        Performs one training step on a sample of data.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer.
        image_data : torch.Tensor
            A tensor of shape (batch_size, channels, height, width) representing
            images normalized using the VGG-16 convention (BGR, ImageNet channel-wise
            mean-centered).
        anchor_map : np.ndarray
            Map of anchors, shaped (height, width, num_anchors * 4). The last
            dimension contains the anchor boxes specified as a 4-tuple of
            (center_y, center_x, height, width), repeated for all anchors at that
            coordinate of the feature map. Passed unchanged to the RPN stage.
        anchor_valid_map : np.ndarray
            Map indicating which anchors are valid (do not intersect image bounds),
            shaped (height, width, num_anchors). Passed unchanged to the RPN stage.
        gt_rpn_map : torch.Tensor
            Ground truth RPN map of shape
            (batch_size, height, width, num_anchors, 6), where height and width are
            the feature map dimensions, not the input image dimensions. The final
            dimension contains:
              - 0: Trainable anchor (1) or not (0). Only valid and non-neutral (that
                   is, definitely positive or negative) anchors are trainable. This is
                   the same as anchor_valid_map with additional invalid anchors caused
                   by neutral samples
              - 1: For trainable anchors, whether the anchor is an object anchor (1)
                   or background anchor (0). For non-trainable anchors, will be 0.
              - 2: Regression target for box center, ty.
              - 3: Regression target for box center, tx.
              - 4: Regression target for box size, th.
              - 5: Regression target for box size, tw.
        gt_rpn_object_indices : List[np.ndarray]
            For each image in the batch, a map of shape (N, 3) of indices (y, x, k)
            of all N object anchors in the RPN ground truth map.
        gt_rpn_background_indices : List[np.ndarray]
            For each image in the batch, a map of shape (M, 3) of indices of all M
            background anchors in the RPN ground truth map.
        gt_boxes : List[List[Box]]
            For each image in the batch, a list of ground truth object boxes. Each
            box must support box["corners"] (y1, x1, y2, x2) and box["class_index"].

        Returns
        -------
        FasterRCNNModel.Loss
            Loss (a dataclass with class and regression losses for both the RPN and
            detector stages).
        """
        self.train()

        # Clear accumulated gradient
        optimizer.zero_grad()

        # For now, we only support a batch size of 1
        assert image_data.shape[0] == 1, "Batch size must be 1"
        assert len(gt_rpn_map.shape) == 5 and gt_rpn_map.shape[0] == 1, "Batch size must be 1"
        assert len(gt_rpn_object_indices) == 1, "Batch size must be 1"
        assert len(gt_rpn_background_indices) == 1, "Batch size must be 1"
        assert len(gt_boxes) == 1, "Batch size must be 1"
        image_shape = image_data.shape[1:]  # (num_channels, height, width)

        # Stage 1: Extract features
        feature_map = self._stage1_feature_extractor(image_data = image_data)

        # Stage 2: Generate object proposals using RPN
        rpn_score_map, rpn_box_deltas_map, proposals = self._stage2_region_proposal_network(
            feature_map = feature_map,
            image_shape = image_shape,  # each image in batch has identical shape: (num_channels, height, width)
            anchor_map = anchor_map,
            anchor_valid_map = anchor_valid_map,
            max_proposals_pre_nms = 12000,
            max_proposals_post_nms = 2000
        )

        # Sample random mini-batch of anchors (for RPN training)
        gt_rpn_minibatch_map = self._sample_rpn_minibatch(
            rpn_map = gt_rpn_map,
            object_indices = gt_rpn_object_indices,
            background_indices = gt_rpn_background_indices
        )

        # Assign labels to proposals and take random sample (for detector training)
        proposals, gt_classes, gt_box_deltas = self._label_proposals(
            proposals = proposals,
            gt_boxes = gt_boxes[0],  # for now, batch size of 1
            min_background_iou_threshold = 0.0,
            min_object_iou_threshold = 0.5
        )
        proposals, gt_classes, gt_box_deltas = self._sample_proposals(
            proposals = proposals,
            gt_classes = gt_classes,
            gt_box_deltas = gt_box_deltas,
            max_proposals = self._proposal_batch_size,
            positive_fraction = 0.25
        )

        # Make sure RoI proposals and ground truths are detached from the
        # computational graph so that gradients are not propagated through them.
        # They are treated as constant inputs into the detector stage.
        proposals = proposals.detach()
        gt_classes = gt_classes.detach()
        gt_box_deltas = gt_box_deltas.detach()

        # Stage 3: Detector
        detector_classes, detector_box_deltas = self._stage3_detector_network(
            feature_map = feature_map,
            proposals = proposals
        )

        # Compute losses
        rpn_class_loss = _rpn_class_loss(predicted_scores = rpn_score_map, y_true = gt_rpn_minibatch_map)
        rpn_regression_loss = _rpn_regression_loss(predicted_box_deltas = rpn_box_deltas_map, y_true = gt_rpn_minibatch_map)
        detector_class_loss = _detector_class_loss(predicted_classes = detector_classes, y_true = gt_classes)
        detector_regression_loss = _detector_regression_loss(predicted_box_deltas = detector_box_deltas, y_true = gt_box_deltas)
        total_loss = rpn_class_loss + rpn_regression_loss + detector_class_loss + detector_regression_loss

        # .item() copies the scalar to a Python float; no explicit detach/cpu needed
        loss = FasterRCNNModel.Loss(
            rpn_class = rpn_class_loss.item(),
            rpn_regression = rpn_regression_loss.item(),
            detector_class = detector_class_loss.item(),
            detector_regression = detector_regression_loss.item(),
            total = total_loss.item()
        )

        # Backprop
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Return losses and data useful for computing statistics
        return loss

    def _sample_rpn_minibatch(self, rpn_map, object_indices, background_indices):
        """
        Selects anchors for training and produces a copy of the RPN ground truth
        map with only those anchors marked as trainable.

        Up to half of the mini-batch is drawn from the object anchors; the rest
        are background anchors.

        Parameters
        ----------
        rpn_map : torch.Tensor
            RPN ground truth map of shape
            (batch_size, height, width, num_anchors, 6).
        object_indices : List[np.ndarray]
            For each image in the batch, a map of shape (N, 3) of indices (y, x, k)
            of all N object anchors in the RPN ground truth map.
        background_indices : List[np.ndarray]
            For each image in the batch, a map of shape (M, 3) of indices of all M
            background anchors in the RPN ground truth map.

        Returns
        -------
        torch.Tensor
            A copy of the RPN ground truth map with index 0 of the last dimension
            recomputed to include only anchors in the minibatch. The input map is
            not modified.
        """
        assert rpn_map.shape[0] == 1, "Batch size must be 1"
        assert len(object_indices) == 1, "Batch size must be 1"
        assert len(background_indices) == 1, "Batch size must be 1"
        positive_anchors = object_indices[0]
        negative_anchors = background_indices[0]
        assert len(positive_anchors) + len(negative_anchors) >= self._rpn_minibatch_size, "Image has insufficient anchors for RPN minibatch size of %d" % self._rpn_minibatch_size
        assert len(positive_anchors) > 0, "Image does not have any positive anchors"
        assert self._rpn_minibatch_size % 2 == 0, "RPN minibatch size must be evenly divisible"

        # Sample, producing indices into the index maps
        num_positive_anchors = len(positive_anchors)
        num_negative_anchors = len(negative_anchors)
        num_positive_samples = min(self._rpn_minibatch_size // 2, num_positive_anchors)  # up to half the samples should be positive, if possible
        num_negative_samples = self._rpn_minibatch_size - num_positive_samples          # the rest should be negative
        positive_anchor_idxs = random.sample(range(num_positive_anchors), num_positive_samples)
        negative_anchor_idxs = random.sample(range(num_negative_anchors), num_negative_samples)

        # Construct index expressions into RPN map. Batch indices are integers
        # (previously float64, relying on an implicit cast during indexing).
        positive_anchors = positive_anchors[positive_anchor_idxs]
        negative_anchors = negative_anchors[negative_anchor_idxs]
        trainable_anchors = np.concatenate([ positive_anchors, negative_anchors ])
        batch_idxs = np.zeros(len(trainable_anchors), dtype = np.int64)
        trainable_idxs = (batch_idxs, trainable_anchors[:,0], trainable_anchors[:,1], trainable_anchors[:,2], 0)

        # Create a copy of the RPN map with samples set as trainable
        rpn_minibatch_map = rpn_map.clone()
        rpn_minibatch_map[:,:,:,:,0] = 0
        rpn_minibatch_map[trainable_idxs] = 1

        return rpn_minibatch_map

    def _label_proposals(self, proposals, gt_boxes, min_background_iou_threshold, min_object_iou_threshold):
        """
        Determines which proposals generated by the RPN stage overlap with ground
        truth boxes and creates ground truth labels for the subsequent detector
        stage.

        Parameters
        ----------
        proposals : torch.Tensor
            Proposal corners, shaped (N, 4). All tensors created here are placed
            on the same device as this tensor.
        gt_boxes : List[Box]
            Ground truth object boxes, each supporting box["corners"] and
            box["class_index"].
        min_background_iou_threshold : float
            Minimum IoU threshold with ground truth boxes below which proposals are
            ignored entirely. Proposals with an IoU threshold in the range
            [min_background_iou_threshold, min_object_iou_threshold) are labeled as
            background. This value can be greater than 0, which has the effect of
            selecting more difficult background examples that have some degree of
            overlap with ground truth boxes.
        min_object_iou_threshold : float
            Minimum IoU threshold for a proposal to be labeled as an object.

        Returns
        -------
        torch.Tensor, torch.Tensor, torch.Tensor
            Proposals, (N, 4), labeled as either objects or background (depending on
            IoU thresholds, some proposals can end up as neither and are excluded
            here); one-hot encoded class labels, (N, num_classes), for each proposal;
            and box delta regression targets, (N, 2, (num_classes - 1) * 4), for each
            proposal. Box delta target values are present at locations [:,1,:] and
            consist of (ty, tx, th, tw) for the class that the box corresponds to.
            The entries for all other classes and the background classes should be
            ignored. A mask is written to locations [:,0,:]. For each proposal
            assigned a non-background class, there will be 4 consecutive elements
            marked with 1 indicating the corresponding box delta target values are to
            be used. There are no box delta regression targets for background
            proposals and the mask is entirely 0 for those proposals.
        """
        assert min_background_iou_threshold < min_object_iou_threshold, "Object threshold must be greater than background threshold"

        # Keep every tensor built below on the proposals' device
        device = proposals.device

        # Convert ground truth box corners to (M,4) tensor and class indices to (M,)
        gt_box_corners = np.array([ box["corners"] for box in gt_boxes ], dtype = np.float32)
        gt_box_corners = t.from_numpy(gt_box_corners).to(device)
        gt_box_class_idxs = t.tensor([ box["class_index"] for box in gt_boxes ], dtype = t.long, device = device)

        # Let's be crafty and create some fake proposals that match the ground
        # truth boxes exactly. This isn't strictly necessary and the model should
        # work without it but it will help training and will ensure that there are
        # always some positive examples to train on.
        proposals = t.vstack([ proposals, gt_box_corners ])

        # Compute IoU between each proposal (N,4) and each ground truth box (M,4)
        # -> (N, M)
        ious = t_intersection_over_union(boxes1 = proposals, boxes2 = gt_box_corners)

        # Find the best IoU for each proposal, the class of the ground truth box
        # associated with it, and the box corners
        best_ious = t.max(ious, dim = 1).values         # (N,) of maximum IoUs for each of the N proposals
        box_idxs = t.argmax(ious, dim = 1)              # (N,) of ground truth box index for each proposal
        gt_box_class_idxs = gt_box_class_idxs[box_idxs] # (N,) of class indices of highest-IoU box for each proposal
        gt_box_corners = gt_box_corners[box_idxs]       # (N,4) of box corners of highest-IoU box for each proposal

        # Remove all proposals whose best IoU is less than the minimum threshold
        # for a negative (background) sample. Note that there is no separate
        # IoU > 0 check: with a threshold of 0, every proposal is kept, including
        # ones that do not overlap any ground truth box (they become background).
        idxs = t.where(best_ious >= min_background_iou_threshold)[0]  # keep proposals w/ sufficiently high IoU
        proposals = proposals[idxs]
        best_ious = best_ious[idxs]
        gt_box_class_idxs = gt_box_class_idxs[idxs]
        gt_box_corners = gt_box_corners[idxs]

        # IoUs less than min_object_iou_threshold will be labeled as background
        gt_box_class_idxs[best_ious < min_object_iou_threshold] = 0

        # One-hot encode class labels
        num_proposals = proposals.shape[0]
        gt_classes = t.zeros((num_proposals, self._num_classes), dtype = t.float32, device = device)  # (N,num_classes)
        gt_classes[ t.arange(num_proposals), gt_box_class_idxs ] = 1.0

        # Convert proposals and ground truth boxes into "anchor" format (center
        # points and side lengths). For the detector stage, the proposals serve as
        # the anchors relative to which the final box predictions will be
        # regressed.
        proposal_centers = 0.5 * (proposals[:,0:2] + proposals[:,2:4])          # center_y, center_x
        proposal_sides = proposals[:,2:4] - proposals[:,0:2]                    # height, width
        gt_box_centers = 0.5 * (gt_box_corners[:,0:2] + gt_box_corners[:,2:4])  # center_y, center_x
        gt_box_sides = gt_box_corners[:,2:4] - gt_box_corners[:,0:2]            # height, width

        # Compute box delta regression targets (ty, tx, th, tw) for each proposal
        # based on the best box selected:
        #   ty = (gt_center_y - proposal_center_y) / proposal_height
        #   tx = (gt_center_x - proposal_center_x) / proposal_width
        #   th = log(gt_height / proposal_height)
        #   tw = log(gt_width / proposal_width)
        box_delta_targets = t.empty((num_proposals, 4), dtype = t.float32, device = device)  # (N,4)
        box_delta_targets[:,0:2] = (gt_box_centers - proposal_centers) / proposal_sides
        box_delta_targets[:,2:4] = t.log(gt_box_sides / proposal_sides)
        box_delta_means = t.tensor(self._detector_box_delta_means, dtype = t.float32, device = device)
        box_delta_stds = t.tensor(self._detector_box_delta_stds, dtype = t.float32, device = device)
        box_delta_targets[:,:] -= box_delta_means  # mean adjustment
        box_delta_targets[:,:] /= box_delta_stds   # standard deviation scaling

        # Convert regression targets into a map of shape (N,2,4*(C-1)) where C is
        # the number of classes and [:,0,:] specifies a mask for the corresponding
        # target components at [:,1,:]. Targets are ordered (ty, tx, th, tw).
        # Background class 0 is not present at all.
        gt_box_deltas = t.zeros((num_proposals, 2, 4 * (self._num_classes - 1)), dtype = t.float32, device = device)
        gt_box_deltas[:,0,:] = t.repeat_interleave(gt_classes, repeats = 4, dim = 1)[:,4:]  # create masks using interleaved repetition, remembering to ignore class 0
        gt_box_deltas[:,1,:] = t.tile(box_delta_targets, dims = (1, self._num_classes - 1))  # populate regression targets with straightforward repetition (only those columns corresponding to class are masked on)

        return proposals, gt_classes, gt_box_deltas

    def _sample_proposals(self, proposals, gt_classes, gt_box_deltas, max_proposals, positive_fraction):
        """
        Randomly samples a mix of object (positive) and background (negative)
        proposals for detector training.

        Parameters
        ----------
        proposals : torch.Tensor
            Proposals, (N, 4).
        gt_classes : torch.Tensor
            One-hot class labels, (N, num_classes); class 0 is background.
        gt_box_deltas : torch.Tensor
            Box delta targets and masks, (N, 2, (num_classes - 1) * 4).
        max_proposals : int
            Desired number of samples. If <= 0, the inputs are returned unchanged.
        positive_fraction : float
            Fraction of the samples that should be positive, if enough exist.

        Returns
        -------
        torch.Tensor, torch.Tensor, torch.Tensor
            The sampled rows of proposals, gt_classes and gt_box_deltas (positives
            first, then negatives). If there are no positive or no negative
            samples available, zero-length tensors are returned.
        """
        if max_proposals <= 0:
            return proposals, gt_classes, gt_box_deltas

        # Get positive and negative (background) proposals
        class_indices = t.argmax(gt_classes, dim = 1)  # (N,num_classes) -> (N,), where each element is the class index (highest score from its row)
        positive_indices = t.where(class_indices > 0)[0]
        negative_indices = t.where(class_indices <= 0)[0]
        num_positive_proposals = len(positive_indices)
        num_negative_proposals = len(negative_indices)

        # Select positive and negative samples, if there are enough. Note that the
        # number of positive samples can be either the positive fraction of the
        # *actual* number of proposals *or* the *desired* number (max_proposals).
        # In practice, these yield virtually identical results but the latter
        # method will yield slightly more positive samples in the rare cases when
        # the number of proposals is below the desired number. Here, we use the
        # former method but others, such as Yun Chen, use the latter. To implement
        # it, replace num_samples with max_proposals in the line that computes
        # num_positive_samples. I am not sure what the original Faster R-CNN
        # implementation does.
        num_samples = min(max_proposals, len(class_indices))
        num_positive_samples = min(round(num_samples * positive_fraction), num_positive_proposals)
        num_negative_samples = min(num_samples - num_positive_samples, num_negative_proposals)

        # Do we have enough?
        if num_positive_samples <= 0 or num_negative_samples <= 0:
            return proposals[[]], gt_classes[[]], gt_box_deltas[[]]  # return 0-length tensors

        # Sample randomly
        positive_sample_indices = positive_indices[ t.randperm(len(positive_indices))[0:num_positive_samples] ]
        negative_sample_indices = negative_indices[ t.randperm(len(negative_indices))[0:num_negative_samples] ]
        indices = t.cat([ positive_sample_indices, negative_sample_indices ])

        # Return
        return proposals[indices], gt_classes[indices], gt_box_deltas[indices]
class RegionProposalNetwork(nn.Module):
    """
    Region proposal network (stage 2 of Faster R-CNN). Predicts an objectness
    score and box deltas for each of 9 anchors at every feature map location
    and converts the best-scoring anchors into box proposals.
    """

    def __init__(self, allow_edge_proposals = False):
        """
        Parameters
        ----------
        allow_edge_proposals : bool
            If True, proposals are generated from all anchors. If False, only
            anchors marked valid in anchor_valid_map are used.
        """
        super().__init__()

        # Constants
        self._allow_edge_proposals = allow_edge_proposals

        # Layers
        num_anchors = 9
        self._rpn_conv1 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = (3, 3), stride = 1, padding = "same")
        self._rpn_class = nn.Conv2d(in_channels = 512, out_channels = num_anchors, kernel_size = (1, 1), stride = 1, padding = "same")
        self._rpn_boxes = nn.Conv2d(in_channels = 512, out_channels = num_anchors * 4, kernel_size = (1, 1), stride = 1, padding = "same")

        # Initialize weights
        self._rpn_conv1.weight.data.normal_(mean = 0.0, std = 0.01)
        self._rpn_conv1.bias.data.zero_()
        self._rpn_class.weight.data.normal_(mean = 0.0, std = 0.01)
        self._rpn_class.bias.data.zero_()
        self._rpn_boxes.weight.data.normal_(mean = 0.0, std = 0.01)
        self._rpn_boxes.bias.data.zero_()

    def forward(self, feature_map, image_shape, anchor_map, anchor_valid_map, max_proposals_pre_nms, max_proposals_post_nms):
        """
        Predict objectness scores and regress region-of-interest box proposals on
        an input feature map.

        Parameters
        ----------
        feature_map : torch.Tensor
            Feature map of shape (batch_size, 512, height, width).
        image_shape : Tuple[int, int, int]
            Shapes of each image in pixels: (num_channels, height, width).
        anchor_map : np.ndarray
            Map of anchors, shaped (height, width, num_anchors * 4). The last
            dimension contains the anchor boxes specified as a 4-tuple of
            (center_y, center_x, height, width), repeated for all anchors at that
            coordinate of the feature map.
        anchor_valid_map : np.ndarray
            Map indicating which anchors are valid (do not intersect image bounds),
            shaped (height, width, num_anchors).
        max_proposals_pre_nms : int
            How many of the best proposals (sorted by objectness score) to extract
            before applying non-maximum suppression.
        max_proposals_post_nms : int
            How many of the best proposals (sorted by objectness score) to keep after
            non-maximum suppression.

        Returns
        -------
        torch.Tensor, torch.Tensor, torch.Tensor
            - Objectness scores (batch_size, height, width, num_anchors)
            - Box regressions (batch_size, height, width, num_anchors * 4), as box
              deltas (that is, (ty, tx, th, tw) for each anchor)
            - Proposals (N, 4) -- all corresponding proposal box corners stored as
              (y1, x1, y2, x2). These are detached from the autograd graph.
        """
        # Pass through the network
        y = F.relu(self._rpn_conv1(feature_map))
        objectness_score_map = t.sigmoid(self._rpn_class(y))
        box_deltas_map = self._rpn_boxes(y)

        # Transpose shapes to be more convenient:
        #   objectness_score_map -> (batch_size, height, width, num_anchors)
        #   box_deltas_map       -> (batch_size, height, width, num_anchors * 4)
        objectness_score_map = objectness_score_map.permute(0, 2, 3, 1).contiguous()
        box_deltas_map = box_deltas_map.permute(0, 2, 3, 1).contiguous()

        # Flatten to per-anchor rows (anchors remain a NumPy array; scores and
        # deltas remain tensors on the network's device)
        anchors, objectness_scores, box_deltas = self._extract_valid(
            anchor_map = anchor_map,
            anchor_valid_map = anchor_valid_map,
            objectness_score_map = objectness_score_map,
            box_deltas_map = box_deltas_map
        )

        # Proposals are constant inputs to the detector stage (the caller also
        # detaches them), so no gradient needs to flow through the box conversion,
        # clipping, or ranking below. Detaching here avoids building an autograd
        # graph for t_convert_deltas_to_boxes(); the original author reports a
        # memory leak involving that function when this detach is omitted.
        # Numerical results are unaffected.
        box_deltas = box_deltas.detach()
        objectness_scores = objectness_scores.detach()
        device = box_deltas.device

        # Convert regressions to box corners
        proposals = t_convert_deltas_to_boxes(
            box_deltas = box_deltas,
            anchors = t.from_numpy(anchors).to(device),
            box_delta_means = t.tensor([0, 0, 0, 0], dtype = t.float32, device = device),
            box_delta_stds = t.tensor([1, 1, 1, 1], dtype = t.float32, device = device)
        )

        # Keep only the top-N scores. Note that we do not care whether the
        # proposals were labeled as objects (score > 0.5) and perform a simple
        # ranking among all of them. Restricting them has a strong adverse impact
        # on training performance.
        sorted_indices = t.argsort(objectness_scores)                                    # sort in ascending order of objectness score
        sorted_indices = sorted_indices.flip(dims = (0,))                                # descending order of score
        proposals = proposals[sorted_indices][0:max_proposals_pre_nms]                   # grab the top-N best proposals
        objectness_scores = objectness_scores[sorted_indices][0:max_proposals_pre_nms]   # corresponding scores

        # Clip to image boundaries (image_shape is (channels, height, width))
        proposals[:,0:2] = t.clamp(proposals[:,0:2], min = 0)
        proposals[:,2] = t.clamp(proposals[:,2], max = image_shape[1])
        proposals[:,3] = t.clamp(proposals[:,3], max = image_shape[2])

        # Remove anything less than 16 pixels on a side
        height = proposals[:,2] - proposals[:,0]
        width = proposals[:,3] - proposals[:,1]
        idxs = t.where((height >= 16) & (width >= 16))[0]
        proposals = proposals[idxs]
        objectness_scores = objectness_scores[idxs]

        # Perform NMS
        idxs = nms(
            boxes = proposals,
            scores = objectness_scores,
            iou_threshold = 0.7
        )
        idxs = idxs[0:max_proposals_post_nms]
        proposals = proposals[idxs]

        # Return network outputs and extracted object proposals, all as PyTorch
        # tensors
        return objectness_score_map, box_deltas_map, proposals

    def _extract_valid(self, anchor_map, anchor_valid_map, objectness_score_map, box_deltas_map):
        """
        Flattens anchors, scores and box deltas to one row per anchor and, unless
        edge proposals are allowed, keeps only rows for anchors marked valid.

        Parameters
        ----------
        anchor_map : np.ndarray
            (height, width, num_anchors * 4) anchor boxes.
        anchor_valid_map : np.ndarray
            (height, width, num_anchors) validity flags (> 0 means valid).
        objectness_score_map : torch.Tensor
            (1, height, width, num_anchors) objectness scores.
        box_deltas_map : torch.Tensor
            (1, height, width, num_anchors * 4) predicted box deltas.

        Returns
        -------
        np.ndarray, torch.Tensor, torch.Tensor
            Anchors (N, 4), scores (N,), and box deltas (N, 4).
        """
        assert objectness_score_map.shape[0] == 1  # only batch size of 1 supported for now
        height, width, num_anchors = anchor_valid_map.shape
        num_rows = height * width * num_anchors
        anchors = anchor_map.reshape((num_rows, 4))                 # [N,4] all anchors
        anchors_valid = anchor_valid_map.reshape((num_rows))        # [N,] whether anchors are valid (i.e., do not cross image boundaries)
        scores = objectness_score_map.reshape((num_rows))           # [N,] predicted objectness scores
        box_deltas = box_deltas_map.reshape((num_rows, 4))          # [N,4] predicted box delta regression targets
        if self._allow_edge_proposals:
            # Use all proposals
            return anchors, scores, box_deltas
        # Filter out those proposals generated at invalid anchors
        idxs = anchors_valid > 0
        return anchors[idxs], scores[idxs], box_deltas[idxs]
def _rpn_class_loss(predicted_scores, y_true): | |
""" | |
Computes RPN class loss. | |
Parameters | |
---------- | |
predicted_scores : torch.Tensor | |
A tensor of shape (batch_size, height, width, num_anchors) containing | |
objectness scores (0 = background, 1 = object). | |
y_true : torch.Tensor | |
Ground truth tensor of shape (batch_size, height, width, num_anchors, 6). | |
Returns | |
------- | |
torch.Tensor | |
Scalar loss. | |
""" | |
epsilon = 1e-7 | |
# y_true_class: (batch_size, height, width, num_anchors), same as predicted_scores | |
y_true_class = y_true[:,:,:,:,1].reshape(predicted_scores.shape) | |
y_predicted_class = predicted_scores | |
# y_mask: y_true[:,:,:,0] is 1.0 for anchors included in the mini-batch | |
y_mask = y_true[:,:,:,:,0].reshape(predicted_scores.shape) | |
# Compute how many anchors are actually used in the mini-batch (e.g., | |
# typically 256) | |
N_cls = t.count_nonzero(y_mask) + epsilon | |
# Compute element-wise loss for all anchors | |
loss_all_anchors = F.binary_cross_entropy(input = y_predicted_class, target = y_true_class, reduction = "none") | |
# Zero out the ones which should not have been included | |
relevant_loss_terms = y_mask * loss_all_anchors | |
# Sum the total loss and normalize by the number of anchors used | |
return t.sum(relevant_loss_terms) / N_cls | |
def _rpn_regression_loss(predicted_box_deltas, y_true): | |
""" | |
Computes RPN box delta regression loss. | |
Parameters | |
---------- | |
predicted_box_deltas : torch.Tensor | |
A tensor of shape (batch_size, height, width, num_anchors * 4) containing | |
RoI box delta regressions for each anchor, stored as: ty, tx, th, tw. | |
y_true : torch.Tensor | |
Ground truth tensor of shape (batch_size, height, width, num_anchors, 6). | |
Returns | |
------- | |
torch.Tensor | |
Scalar loss. | |
""" | |
epsilon = 1e-7 | |
scale_factor = 1.0 # hyper-parameter that controls magnitude of regression loss and is chosen to make regression term comparable to class term | |
sigma = 3.0 # see: https://github.com/rbgirshick/py-faster-rcnn/issues/89 | |
sigma_squared = sigma * sigma | |
y_predicted_regression = predicted_box_deltas | |
y_true_regression = y_true[:,:,:,:,2:6].reshape(y_predicted_regression.shape) | |
# Include only anchors that are used in the mini-batch and which correspond | |
# to objects (positive samples) | |
y_included = y_true[:,:,:,:,0].reshape(y_true.shape[0:4]) # trainable anchors map: (batch_size, height, width, num_anchors) | |
y_positive = y_true[:,:,:,:,1].reshape(y_true.shape[0:4]) # positive anchors | |
y_mask = y_included * y_positive | |
# y_mask is of the wrong shape. We have one value per (y,x,k) position but in | |
# fact need to have 4 values (one for each of the regression variables). For | |
# example, y_predicted might be (1,37,50,36) and y_mask will be (1,37,50,9). | |
# We need to repeat the last dimension 4 times. | |
y_mask = y_mask.repeat_interleave(repeats = 4, dim = 3) | |
# The paper normalizes by dividing by a quantity called N_reg, which is equal | |
# to the total number of anchors (~2400) and then multiplying by lambda=10. | |
# This does not make sense to me because we are summing over a mini-batch at | |
# most, so we use N_cls here. I might be misunderstanding what is going on | |
# but 10/2400 = 1/240 which is pretty close to 1/256 and the paper mentions | |
# that training is relatively insensitve to choice of normalization. | |
N_cls = t.count_nonzero(y_included) + epsilon | |
# Compute element-wise loss using robust L1 function for all 4 regression | |
# components | |
x = y_true_regression - y_predicted_regression | |
x_abs = t.abs(x) | |
is_negative_branch = (x_abs < (1.0 / sigma_squared)).float() | |
R_negative_branch = 0.5 * x * x * sigma_squared | |
R_positive_branch = x_abs - 0.5 / sigma_squared | |
loss_all_anchors = is_negative_branch * R_negative_branch + (1.0 - is_negative_branch) * R_positive_branch | |
# Zero out the ones which should not have been included | |
relevant_loss_terms = y_mask * loss_all_anchors | |
return scale_factor * t.sum(relevant_loss_terms) / N_cls | |
class DetectorNetwork(nn.Module):
    """
    Detector network (stage 3 of Faster R-CNN). RoI-pools each proposal to a
    7x7 grid of the feature map and predicts class probabilities and per-class
    box deltas through two fully-connected layers.
    """

    def __init__(self, num_classes, dropout_probability):
        """
        Parameters
        ----------
        num_classes : int
            Number of classes, including the background class 0 (which gets no
            box regression outputs).
        dropout_probability : float
            Dropout probability applied after each fully-connected layer.
        """
        super().__init__()

        # Define network. spatial_scale of 1/16 presumably matches the stride of
        # the feature extractor -- confirm against FeatureExtractor.
        self._roi_pool = RoIPool(output_size = (7, 7), spatial_scale = 1.0 / 16.0)
        self._fc1 = nn.Linear(in_features = 512*7*7, out_features = 4096)
        self._fc2 = nn.Linear(in_features = 4096, out_features = 4096)
        self._classifier = nn.Linear(in_features = 4096, out_features = num_classes)
        self._regressor = nn.Linear(in_features = 4096, out_features = (num_classes - 1) * 4)

        # Dropout layers
        self._dropout1 = nn.Dropout(p = dropout_probability)
        self._dropout2 = nn.Dropout(p = dropout_probability)

        # Initialize weights
        self._classifier.weight.data.normal_(mean = 0.0, std = 0.01)
        self._classifier.bias.data.zero_()
        self._regressor.weight.data.normal_(mean = 0.0, std = 0.001)
        self._regressor.bias.data.zero_()

    def forward(self, feature_map, proposals):
        """
        Predict final class and box delta regressions for region-of-interest
        proposals. The proposals serve as "anchors" for the box deltas, which
        refine the proposals into final boxes.

        Parameters
        ----------
        feature_map : torch.Tensor
            Feature map of shape (batch_size, 512, height, width).
        proposals : torch.Tensor
            Region-of-interest box proposals that are likely to contain objects.
            Has shape (N, 4), where N is the number of proposals, with each box given
            as (y1, x1, y2, x2) in pixel coordinates.

        Returns
        -------
        torch.Tensor, torch.Tensor
            Predicted class probabilities, (N, num_classes), as a softmax over
            classes for each proposal, and predicted box delta regressions,
            (N, 4*(num_classes-1)), where the deltas are expressed as
            (ty, tx, th, tw) and are relative to each corresponding proposal box.
            Because there is no box for the background class 0, it is excluded
            entirely and only (num_classes-1) sets of box delta targets are
            computed.
        """
        # Batch size of one for now, so every proposal belongs to batch index 0
        assert feature_map.shape[0] == 1, "Batch size must be 1"
        batch_idxs = t.zeros((proposals.shape[0], 1), device = proposals.device)

        # (N, 5) tensor of (batch_idx, y1, x1, y2, x2)
        indexed_proposals = t.cat([ batch_idxs, proposals ], dim = 1)
        # Reorder each row to (batch_idx, x1, y1, x2, y2), the layout RoIPool expects
        indexed_proposals = indexed_proposals[:, [ 0, 2, 1, 4, 3 ]]

        # RoI pooling: (N, 512, 7, 7)
        rois = self._roi_pool(feature_map, indexed_proposals)
        rois = rois.reshape((rois.shape[0], 512*7*7))  # flatten each RoI: (N, 512*7*7)

        # Forward propagate
        y1o = F.relu(self._fc1(rois))
        y1 = self._dropout1(y1o)
        y2o = F.relu(self._fc2(y1))
        y2 = self._dropout2(y2o)
        classes_raw = self._classifier(y2)
        classes = F.softmax(classes_raw, dim = 1)
        box_deltas = self._regressor(y2)

        return classes, box_deltas
def _detector_class_loss(predicted_classes, y_true): | |
""" | |
Computes detector class loss. | |
Parameters | |
---------- | |
predicted_classes : torch.Tensor | |
RoI predicted classes as categorical vectors, (N, num_classes). | |
y_true : torch.Tensor | |
RoI class labels as categorical vectors, (N, num_classes). | |
Returns | |
------- | |
torch.Tensor | |
Scalar loss. | |
""" | |
epsilon = 1e-7 | |
scale_factor = 1.0 | |
cross_entropy_per_row = -(y_true * t.log(predicted_classes + epsilon)).sum(dim = 1) | |
N = cross_entropy_per_row.shape[0] + epsilon | |
cross_entropy = t.sum(cross_entropy_per_row) / N | |
return scale_factor * cross_entropy | |
def _detector_regression_loss(predicted_box_deltas, y_true): | |
""" | |
Computes detector regression loss. | |
Parameters | |
---------- | |
predicted_box_deltas : torch.Tensor | |
RoI predicted box delta regressions, (N, 4*(num_classes-1)). The background | |
class is excluded and only the non-background classes are included. Each | |
set of box deltas is stored in parameterized form as (ty, tx, th, tw). | |
y_true : torch.Tensor | |
RoI box delta regression ground truth labels, (N, 2, 4*(num_classes-1)). | |
These are stored as mask values (1 or 0) in (:,0,:) and regression | |
parameters in (:,1,:). Note that it is important to mask off the predicted | |
and ground truth values because they may be set to invalid values. | |
Returns | |
------- | |
torch.Tensor | |
Scalar loss. | |
""" | |
epsilon = 1e-7 | |
scale_factor = 1.0 | |
sigma = 1.0 | |
sigma_squared = sigma * sigma | |
# We want to unpack the regression targets and the mask of valid targets into | |
# tensors each of the same shape as the predicted: | |
# (num_proposals, 4*(num_classes-1)) | |
# y_true has shape: | |
# (num_proposals, 2, 4*(num_classes-1)) | |
y_mask = y_true[:,0,:] | |
y_true_targets = y_true[:,1,:] | |
# Compute element-wise loss using robust L1 function for all 4 regression | |
# targets | |
x = y_true_targets - predicted_box_deltas | |
x_abs = t.abs(x) | |
is_negative_branch = (x < (1.0 / sigma_squared)).float() | |
R_negative_branch = 0.5 * x * x * sigma_squared | |
R_positive_branch = x_abs - 0.5 / sigma_squared | |
losses = is_negative_branch * R_negative_branch + (1.0 - is_negative_branch) * R_positive_branch | |
# Normalize to number of proposals (e.g., 128). Although this may not be | |
# what the paper does, it seems to work. Other implemetnations do this. | |
# Using e.g., the number of positive proposals will cause the loss to | |
# behave erratically because sometimes N will become very small. | |
N = y_true.shape[0] + epsilon | |
relevant_loss_terms = y_mask * losses | |
return scale_factor * t.sum(relevant_loss_terms) / N | |
class FeatureExtractor(nn.Module):
  """
  VGG-16 convolutional backbone: thirteen 3x3 convolutions arranged in five
  blocks, each followed by a ReLU, with 2x2 stride-2 max pooling after each
  of the first four blocks. The first two blocks are frozen, so their
  parameters receive no gradients.
  """

  def __init__(self):
    super().__init__()

    def conv3x3(in_channels, out_channels):
      # Stride-1, "same"-padded 3x3 convolution: preserves spatial size
      return nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = (3, 3), stride = 1, padding = "same")

    def pool2x2():
      # Halves height and width
      return nn.MaxPool2d(kernel_size = (2, 2), stride = 2)

    # Layers are created in a fixed order with fixed attribute names so that
    # parameter registration order and state_dict keys stay stable.
    self._block1_conv1 = conv3x3(3, 64)
    self._block1_conv2 = conv3x3(64, 64)
    self._block1_pool = pool2x2()

    self._block2_conv1 = conv3x3(64, 128)
    self._block2_conv2 = conv3x3(128, 128)
    self._block2_pool = pool2x2()

    self._block3_conv1 = conv3x3(128, 256)
    self._block3_conv2 = conv3x3(256, 256)
    self._block3_conv3 = conv3x3(256, 256)
    self._block3_pool = pool2x2()

    self._block4_conv1 = conv3x3(256, 512)
    self._block4_conv2 = conv3x3(512, 512)
    self._block4_conv3 = conv3x3(512, 512)
    self._block4_pool = pool2x2()

    self._block5_conv1 = conv3x3(512, 512)
    self._block5_conv2 = conv3x3(512, 512)
    self._block5_conv3 = conv3x3(512, 512)

    # Freeze first two convolutional blocks
    for frozen_layer in (self._block1_conv1, self._block1_conv2, self._block2_conv1, self._block2_conv2):
      frozen_layer.weight.requires_grad = False
      frozen_layer.bias.requires_grad = False

  def forward(self, image_data):
    """
    Converts input images into feature maps using VGG-16 convolutional layers.

    Parameters
    ----------
    image_data : torch.Tensor
      A tensor of shape (batch_size, channels, height, width) representing
      images normalized using the VGG-16 convention (BGR, ImageNet channel-wise
      mean-centered).

    Returns
    -------
    torch.Tensor
      Feature map of shape (batch_size, 512, height // 16, width // 16).
    """
    blocks = (
      (self._block1_conv1, self._block1_conv2, self._block1_pool),
      (self._block2_conv1, self._block2_conv2, self._block2_pool),
      (self._block3_conv1, self._block3_conv2, self._block3_conv3, self._block3_pool),
      (self._block4_conv1, self._block4_conv2, self._block4_conv3, self._block4_pool),
      (self._block5_conv1, self._block5_conv2, self._block5_conv3),  # no pooling after the last block
    )
    features = image_data
    for block in blocks:
      for layer in block:
        if isinstance(layer, nn.MaxPool2d):
          features = layer(features)
        else:
          # Every convolution is followed by a ReLU
          features = F.relu(layer(features))
    return features
def t_intersection_over_union(boxes1, boxes2):
  """
  Pairwise intersection-over-union of two sets of boxes; the PyTorch
  counterpart of intersection_over_union().

  Parameters
  ----------
  boxes1 : torch.Tensor
    Box corners, shaped (N, 4), with each box as (y1, x1, y2, x2).
  boxes2 : torch.Tensor
    Box corners, shaped (M, 4).

  Returns
  -------
  torch.Tensor
    IoUs for each pair of boxes in boxes1 and boxes2, shaped (N, M).
  """
  # Broadcast boxes1 (N,1,2) against boxes2 (M,2) to get (N,M,2) corner pairs
  inter_min = t.maximum(boxes1[:,None,0:2], boxes2[:,0:2])  # top-left (y, x) of each candidate intersection
  inter_max = t.minimum(boxes1[:,None,2:4], boxes2[:,2:4])  # bottom-right (y, x)

  # A pair overlaps only if the intersection has positive extent along both
  # axes; otherwise its intersection area is forced to zero.
  overlaps = t.all(inter_min < inter_max, dim = 2)
  intersection = overlaps * t.prod(inter_max - inter_min, dim = 2)  # (N,M)

  area1 = t.prod(boxes1[:,2:4] - boxes1[:,0:2], dim = 1)  # (N,)
  area2 = t.prod(boxes2[:,2:4] - boxes2[:,0:2], dim = 1)  # (M,)
  union = area1[:,None] + area2 - intersection             # (N,M)

  epsilon = 1e-7  # avoids division by zero for degenerate boxes
  return intersection / (union + epsilon)
def t_convert_deltas_to_boxes(box_deltas, anchors, box_delta_means, box_delta_stds):
  """
  Equivalent to convert_deltas_to_boxes(), operating on PyTorch tensors.

  Parameters
  ----------
  box_deltas : torch.Tensor
    Box deltas with shape (N, 4). Each row is (ty, tx, th, tw).
  anchors : torch.Tensor
    Corresponding anchors that the box deltas are based upon, shaped (N, 4)
    with each row being (center_y, center_x, height, width).
  box_delta_means : torch.Tensor
    Mean adjustment to box deltas, (4,), to be added after standard deviation
    scaling and before conversion to actual box coordinates.
  box_delta_stds : torch.Tensor
    Standard deviation adjustment to box deltas, (4,). Box deltas are first
    multiplied by these values.

  Returns
  -------
  torch.Tensor
    Box coordinates, (N, 4), float32, on the same device as box_deltas, with
    each row being (y1, x1, y2, x2).
  """
  # Scale by the standard deviations, then shift by the means
  box_deltas = box_deltas * box_delta_stds + box_delta_means

  # center_y = anchor_height * ty + anchor_center_y, center_x = anchor_width * tx + anchor_center_x
  center = anchors[:,2:4] * box_deltas[:,0:2] + anchors[:,0:2]
  # height = anchor_height * exp(th), width = anchor_width * exp(tw)
  size = anchors[:,2:4] * t.exp(box_deltas[:,2:4])

  # Allocate on the inputs' device rather than hard-coding "cuda", so this
  # also works on CPU or a non-default GPU. float32 is kept for compatibility
  # with existing callers.
  boxes = t.empty(box_deltas.shape, dtype = t.float32, device = box_deltas.device)
  boxes[:,0:2] = center - 0.5 * size  # y1, x1
  boxes[:,2:4] = center + 0.5 * size  # y2, x2
  return boxes
__image_data = np.array([44.060997,44.060997,44.060997,41.060997,39.060997,39.060997,42.060997,30.060997,-20.939003,-59.939003,-88.939,-81.939,-77.939,-74.939,-74.939,-74.939,-76.939,-80.939,-86.939,-91.939,-78.939,-46.939003,-44.939003,-37.939003,-16.939003,11.060997,37.060997,27.060997,21.060997,20.060997,23.060997,26.060997,27.060997,30.060997,33.060997,38.060997,53.060997,78.061,58.060997,37.060997,22.060997,15.060997,11.060997,5.060997,-0.939003,-7.939003,-11.939003,-18.939003,-30.939003,-38.939003,-43.939003,-39.939003,-47.939003,-66.939,-84.939,-96.939,-98.939,-93.939,-86.939,-80.939,-71.939,-59.939003,-54.939003,-49.939003,-44.939003,-42.939003,-41.939003,-46.939003,-47.939003,-44.939003,-41.939003,-39.939003,-44.939003,-42.939003,-37.939003,-24.939003,-8.939003,11.060997,14.060997,14.060997,13.060997,-25.939003,-76.939,-81.939,-82.939,-81.939,-83.939,-86.939,-88.939,-91.939,-93.939,-97.939,-99.939,-98.939,-97.939,-96.939,-95.939,-94.939,-94.939,-94.939,-93.939,-91.939,-92.939,-94.939,-96.939,-96.939,-95.939,-94.939,-95.939,-96.939,-98.939,-98.939,-97.939,-96.939,-95.939,-96.939,-98.939,-100.939,-102.939,-103.939,-102.939,-102.939,-103.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-98.939,-92.939,-97.939,-86.939,-60.939003,-2.939003,34.060997,15.060997,12.060997,11.060997,-31.939003,-56.939003,-62.939003,-56.939003,-50.939003,-48.939003,-45.939003,-40.939003,-28.939003,-25.939003,-32.939003,-32.939003,-31.939003,-30.939003,-30.939003,-32.939003,-31.939003,-31.939003,-32.939003,-35.939003,-36.939003,-35.939003,-38.939003,-43.939003,-47.939003,-48.939003,-47.939003,-56.939003,-64.939,-68.939,-73.939,-78.939,-73.939,-73.939,-78.939,-82.939,-83.939,-82.939,-84.939,-89.939,-88.939,-91.939,-97.939,-95.939,-93.939,-93.939,-94.939,-95.939,-95.939,-92.939,-87.939,-87.939,-88.939,-92.939,-91.939,-88.939,-84.939,-68.939,-41.939003,-24.939003,-10.939003,-2.939003,25.060997,56.060997,41.0
60997,35.060997,38.060997,41.060997,44.060997,47.060997,45.060997,43.060997,57.060997,54.060997,35.060997,13.060997,11.060997,52.060997,56.060997,46.060997,68.061,72.061,58.060997,54.060997,52.060997,54.060997,51.060997,48.060997,47.060997,44.060997,41.060997,39.060997,36.060997,35.060997,28.060997,21.060997,20.060997,11.060997,-4.939003,-9.939003,-8.939003,1.060997,-2.939003,-12.939003,-17.939003,-20.939003,-19.939003,-2.939003,16.060997,31.060997,29.060997,17.060997,0.06099701,-23.939003,-56.939003,-38.939003,-23.939003,-24.939003,-22.939003,-18.939003,-19.939003,-23.939003,-30.939003,-28.939003,-26.939003,-31.939003,-37.939003,-43.939003,-40.939003,-38.939003,-36.939003,-52.939003,-63.939003,-59.939003,-60.939003,-62.939003,-63.939003,-67.939,-72.939,-76.939,-79.939,-81.939,-81.939,-80.939,-82.939,-84.939,-87.939,-85.939,-87.939,-96.939,-91.939,-75.939,-26.939003,-6.939003,-15.939003,-2.939003,-3.939003,-39.939003,-61.939003,-77.939,-94.939,-100.939,-94.939,-90.939,-91.939,-100.939,-55.939003,8.060997,39.060997,54.060997,55.060997,23.060997,-20.939003,-84.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-98.939,-95.939,-100.939,-101.939,-99.939,-89.939,-81.939,-76.939,-81.939,-84.939,-81.939,-77.939,-74.939,-79.939,-74.939,-59.939003,-64.939,-68.939,-67.939,-64.939,-62.939003,-68.939,-65.939,-54.939003,-54.939003,-53.939003,-52.939003,-46.939003,-40.939003,-46.939003,-45.939003,-38.939003,-36.939003,-36.939003,-41.939003,-47.939003,-53.939003,-56.939003,-52.939003,-41.939003,-37.939003,-34.939003,-32.939003,-28.939003,-24.939003,-25.939003,-30.939003,-40.939003,-39.939003,-37.939003,-39.939003,-44.939003,-52.939003,-57.939003,-59.939003,-58.939003,-53.939003,-52.939003,-57.939003,-61.939003,-64.939,-66.939,-68.939,-71.939,-75.939,-77.939,-72.939,-74.939,-78.939,-79.939,-80.939,-81.939,-82.939,-85.939,-90.939,-90.939,-89.939,-96.939,-93.939,-78.939,-54.939003,-35.939003,-33.939003,-45.939003,-
58.939003,-19.939003,15.060997,46.060997,30.060997,9.060997,-10.939003,-25.939003,-35.939003,-28.939003,-25.939003,-25.939003,-12.939003,-4.939003,-9.939003,-4.939003,5.060997,26.060997,36.060997,35.060997,33.060997,32.060997,36.060997,41.060997,45.060997,42.060997,33.060997,17.060997,13.060997,12.060997,16.060997,18.060997,19.060997,17.060997,18.060997,22.060997,24.060997,25.060997,26.060997,26.060997,27.060997,28.060997,28.060997,27.060997,27.060997,28.060997,29.060997,32.060997,34.060997,34.060997,35.060997,36.060997,39.060997,40.060997,37.060997,33.060997,29.060997,32.060997,34.060997,33.060997,35.060997,37.060997,41.060997,42.060997,42.060997,45.060997,45.060997,42.060997,41.060997,40.060997,36.060997,35.060997,34.060997,29.060997,24.060997,20.060997,24.060997,27.060997,29.060997,28.060997,26.060997,24.060997,25.060997,29.060997,27.060997,24.060997,23.060997,22.060997,22.060997,17.060997,16.060997,18.060997,17.060997,15.060997,10.060997,8.060997,8.060997,5.060997,4.060997,5.060997,3.060997,2.060997,0.06099701,-0.939003,-1.939003,-0.939003,-2.939003,-7.939003,-6.939003,2.060997,30.060997,33.060997,26.060997,20.060997,16.060997,13.060997,-1.939003,-20.939003,-42.939003,-44.939003,-39.939003,-64.939,-85.939,-101.939,-103.939,-103.939,-101.939,-102.939,-103.939,-103.939,-67.939,5.060997,25.060997,24.060997,-10.939003,-57.939003,-103.939,-99.939,-97.939,-98.939,-94.939,-91.939,-90.939,-87.939,-83.939,-80.939,-80.939,-81.939,-77.939,-73.939,-68.939,-68.939,-69.939,-69.939,-66.939,-59.939003,-60.939003,-60.939003,-55.939003,-53.939003,-52.939003,-52.939003,-54.939003,-57.939003,-54.939003,-51.939003,-52.939003,-52.939003,-51.939003,-52.939003,-51.939003,-49.939003,-52.939003,-53.939003,-50.939003,-47.939003,-45.939003,-47.939003,-48.939003,-48.939003,-45.939003,-42.939003,-40.939003,-42.939003,-45.939003,-40.939003,-35.939003,-30.939003,-29.939003,-27.939003,-21.939003,-22.939003,-26.939003,-21.939003,-18.939003,-15.939003,-9.939003,-5.939003,-3.939003,6.060997,16.060
997,-14.939003,-39.939003,-59.939003,-51.939003,-44.939003,-49.939003,-49.939003,-47.939003,-45.939003,-43.939003,-42.939003,-42.939003,-41.939003,-39.939003,-35.939003,-32.939003,-35.939003,-37.939003,-39.939003,-37.939003,-35.939003,-33.939003,-37.939003,-42.939003,-30.939003,-22.939003,-19.939003,-23.939003,-30.939003,-38.939003,-37.939003,-32.939003,-34.939003,-35.939003,-36.939003,-39.939003,-41.939003,-40.939003,-42.939003,-44.939003,-47.939003,-45.939003,-39.939003,-42.939003,-46.939003,-49.939003,-47.939003,-43.939003,-44.939003,-45.939003,-44.939003,-42.939003,-42.939003,-47.939003,-53.939003,-58.939003,-61.939003,-62.939003,-62.939003,-60.939003,-58.939003,-56.939003,-58.939003,-60.939003,-59.939003,-61.939003,-64.939,-63.939003,-63.939003,-64.939,-65.939,-64.939,-64.939,-67.939,-72.939,-68.939,-66.939,-71.939,-74.939,-75.939,-71.939,-69.939,-70.939,-68.939,-68.939,-69.939,-73.939,-77.939,-78.939,-80.939,-82.939,-81.939,-72.939,-44.939003,-27.939003,-17.939003,-32.939003,-54.939003,-81.939,-83.939,-83.939,-84.939,-87.939,-90.939,-89.939,-89.939,-91.939,-90.939,-89.939,-91.939,-92.939,-94.939,-94.939,-94.939,-95.939,-94.939,-94.939,-93.939,-95.939,-97.939,-98.939,-99.939,-100.939,-95.939,-93.939,-96.939,-99.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-101.939,-101.939,-78.939,-43.939003,3.060997,15.060997,2.060997,-55.939003,-86.939,-102.939,-98.939,-98.939,-103.939,-102.939,-101.939,-103.939,-103.939,-103.939,-102.939,-99.939,-95.939,-95.939,-96.939,-97.939,-96.939,-94.939,-91.939,-90.939,-93.939,-98.939,-100.939,-95.939,-94.939,-94.939,-85.939,-76.939,-68.939,-83.939,-95.939,-93.939,-92.939,-92.939,-92.939,-92.939,-92.939,-85.939,-80.939,47.060997,47.060997,48.060997,46.060997,44.060997,45.060997,49.060997,39.060997,-4.939003,-50.939003,-92.939,-86.939,-81.939,-77.939,-78.939,-78.939,-77.939,-80.939,-85.939,-87.939,-78.939,-58.939003,-50.939003,-40.939003,-27.939003,2.060997,33.060997,24.060997,19.060997,20.060997,21.060997,23.060997,2
5.060997,26.060997,28.060997,48.060997,67.061,85.061,53.060997,25.060997,18.060997,13.060997,8.060997,1.060997,-4.939003,-10.939003,-12.939003,-17.939003,-29.939003,-46.939003,-63.939003,-65.939,-71.939,-82.939,-91.939,-97.939,-98.939,-95.939,-90.939,-83.939,-75.939,-66.939,-67.939,-66.939,-66.939,-66.939,-67.939,-70.939,-70.939,-68.939,-65.939,-63.939003,-65.939,-66.939,-62.939003,-28.939003,4.060997,36.060997,33.060997,29.060997,30.060997,-16.939003,-78.939,-87.939,-90.939,-88.939,-90.939,-93.939,-94.939,-95.939,-96.939,-99.939,-100.939,-100.939,-99.939,-98.939,-98.939,-98.939,-98.939,-98.939,-97.939,-96.939,-97.939,-98.939,-99.939,-99.939,-97.939,-97.939,-98.939,-99.939,-99.939,-98.939,-98.939,-97.939,-95.939,-96.939,-96.939,-95.939,-94.939,-92.939,-91.939,-89.939,-87.939,-84.939,-82.939,-80.939,-78.939,-77.939,-75.939,-75.939,-75.939,-74.939,-72.939,-71.939,-68.939,-67.939,-68.939,-66.939,-62.939003,-60.939003,-53.939003,-41.939003,-16.939003,-1.939003,-9.939003,-9.939003,-8.939003,-34.939003,-49.939003,-51.939003,-52.939003,-52.939003,-51.939003,-50.939003,-49.939003,-42.939003,-44.939003,-54.939003,-54.939003,-54.939003,-54.939003,-56.939003,-60.939003,-61.939003,-62.939003,-63.939003,-63.939003,-64.939,-64.939,-66.939,-69.939,-71.939,-72.939,-71.939,-77.939,-81.939,-83.939,-86.939,-88.939,-86.939,-86.939,-88.939,-91.939,-92.939,-91.939,-92.939,-95.939,-95.939,-96.939,-100.939,-98.939,-97.939,-97.939,-98.939,-98.939,-98.939,-96.939,-92.939,-93.939,-94.939,-95.939,-95.939,-94.939,-88.939,-71.939,-40.939003,-26.939003,-14.939003,-4.939003,23.060997,54.060997,41.060997,33.060997,31.060997,33.060997,35.060997,36.060997,31.060997,25.060997,33.060997,29.060997,13.060997,0.06099701,-0.939003,25.060997,26.060997,16.060997,26.060997,26.060997,17.060997,14.060997,13.060997,12.060997,9.060997,6.060997,6.060997,4.060997,2.060997,1.060997,-0.939003,-1.939003,-2.939003,-3.939003,-4.939003,-8.939003,-14.939003,-17.939003,-16.939003,-10.939003,-10.939003,-12.939003,-13.939003
,-15.939003,-17.939003,-2.939003,15.060997,29.060997,28.060997,18.060997,4.060997,-25.939003,-72.939,-62.939003,-52.939003,-53.939003,-53.939003,-51.939003,-53.939003,-57.939003,-62.939003,-60.939003,-59.939003,-62.939003,-66.939,-69.939,-68.939,-66.939,-65.939,-74.939,-80.939,-78.939,-79.939,-80.939,-80.939,-83.939,-86.939,-88.939,-89.939,-91.939,-91.939,-90.939,-91.939,-92.939,-94.939,-93.939,-94.939,-99.939,-94.939,-80.939,-27.939003,-6.939003,-16.939003,-2.939003,-2.939003,-38.939003,-60.939003,-76.939,-94.939,-101.939,-98.939,-93.939,-93.939,-100.939,-57.939003,4.060997,37.060997,51.060997,45.060997,15.060997,-24.939003,-82.939,-95.939,-88.939,-89.939,-89.939,-89.939,-88.939,-86.939,-82.939,-80.939,-79.939,-78.939,-77.939,-76.939,-72.939,-70.939,-74.939,-71.939,-66.939,-64.939,-60.939003,-57.939003,-57.939003,-57.939003,-57.939003,-55.939003,-53.939003,-58.939003,-56.939003,-47.939003,-54.939003,-58.939003,-54.939003,-51.939003,-50.939003,-57.939003,-57.939003,-48.939003,-50.939003,-53.939003,-56.939003,-51.939003,-45.939003,-51.939003,-52.939003,-48.939003,-50.939003,-47.939003,-33.939003,-22.939003,-14.939003,-15.939003,-27.939003,-50.939003,-57.939003,-60.939003,-60.939003,-57.939003,-55.939003,-57.939003,-61.939003,-68.939,-67.939,-66.939,-67.939,-70.939,-74.939,-77.939,-78.939,-78.939,-75.939,-74.939,-76.939,-79.939,-81.939,-82.939,-83.939,-85.939,-87.939,-88.939,-86.939,-87.939,-89.939,-89.939,-90.939,-89.939,-91.939,-93.939,-96.939,-96.939,-95.939,-99.939,-94.939,-78.939,-53.939003,-34.939003,-32.939003,-46.939003,-60.939003,-21.939003,12.060997,41.060997,21.060997,-1.939003,-17.939003,-26.939003,-31.939003,-26.939003,-23.939003,-21.939003,-13.939003,-7.939003,-7.939003,0.06099701,11.060997,29.060997,36.060997,31.060997,33.060997,36.060997,37.060997,42.060997,47.060997,37.060997,27.060997,17.060997,17.060997,16.060997,15.060997,16.060997,18.060997,19.060997,21.060997,22.060997,24.060997,25.060997,25.060997,25.060997,26.060997,27.060997,28.060997,28.06099
7,28.060997,29.060997,30.060997,32.060997,34.060997,35.060997,36.060997,34.060997,36.060997,37.060997,37.060997,34.060997,31.060997,34.060997,34.060997,32.060997,33.060997,35.060997,39.060997,41.060997,42.060997,45.060997,45.060997,41.060997,38.060997,37.060997,36.060997,35.060997,34.060997,28.060997,23.060997,20.060997,24.060997,27.060997,27.060997,27.060997,26.060997,25.060997,26.060997,31.060997,27.060997,24.060997,23.060997,22.060997,23.060997,18.060997,16.060997,18.060997,17.060997,15.060997,11.060997,9.060997,9.060997,6.060997,5.060997,4.060997,4.060997,4.060997,1.060997,-0.939003,-1.939003,-1.939003,-3.939003,-7.939003,-6.939003,-0.939003,23.060997,30.060997,30.060997,22.060997,16.060997,11.060997,3.060997,-11.939003,-36.939003,-40.939003,-35.939003,-55.939003,-74.939,-93.939,-91.939,-86.939,-83.939,-83.939,-84.939,-82.939,-60.939003,-17.939003,-9.939003,-11.939003,-28.939003,-52.939003,-76.939,-74.939,-72.939,-73.939,-70.939,-69.939,-69.939,-67.939,-65.939,-63.939003,-64.939,-65.939,-65.939,-64.939,-61.939003,-62.939003,-64.939,-65.939,-64.939,-60.939003,-62.939003,-63.939003,-61.939003,-62.939003,-62.939003,-60.939003,-57.939003,-53.939003,-52.939003,-51.939003,-52.939003,-52.939003,-50.939003,-52.939003,-51.939003,-47.939003,-34.939003,-25.939003,-29.939003,-35.939003,-39.939003,-30.939003,-26.939003,-26.939003,-21.939003,-20.939003,-29.939003,-31.939003,-30.939003,-19.939003,-11.939003,-5.939003,-8.939003,-10.939003,-6.939003,-8.939003,-12.939003,-13.939003,-10.939003,-1.939003,-2.939003,-3.939003,-1.939003,6.060997,13.060997,-17.939003,-40.939003,-58.939003,-51.939003,-45.939003,-47.939003,-46.939003,-43.939003,-44.939003,-44.939003,-45.939003,-45.939003,-42.939003,-39.939003,-33.939003,-28.939003,-31.939003,-33.939003,-35.939003,-35.939003,-34.939003,-32.939003,-40.939003,-49.939003,-18.939003,2.060997,12.060997,2.060997,-6.939003,-4.939003,-9.939003,-15.939003,-14.939003,-10.939003,-4.939003,-5.939003,-7.939003,-8.939003,-11.939003,-15.939003,-30.93900
3,-34.939003,-28.939003,-16.939003,-8.939003,-13.939003,-12.939003,-9.939003,-9.939003,-11.939003,-12.939003,-7.939003,-10.939003,-35.939003,-58.939003,-77.939,-76.939,-75.939,-76.939,-75.939,-73.939,-72.939,-72.939,-71.939,-72.939,-75.939,-78.939,-74.939,-73.939,-76.939,-71.939,-62.939003,-69.939,-75.939,-80.939,-79.939,-78.939,-79.939,-80.939,-80.939,-76.939,-75.939,-76.939,-73.939,-72.939,-72.939,-74.939,-76.939,-78.939,-80.939,-83.939,-81.939,-72.939,-42.939003,-28.939003,-22.939003,-35.939003,-54.939003,-77.939,-78.939,-77.939,-78.939,-79.939,-79.939,-81.939,-81.939,-80.939,-78.939,-77.939,-80.939,-81.939,-83.939,-81.939,-80.939,-80.939,-79.939,-79.939,-80.939,-81.939,-81.939,-81.939,-82.939,-85.939,-81.939,-80.939,-82.939,-83.939,-84.939,-84.939,-84.939,-83.939,-81.939,-81.939,-83.939,-86.939,-86.939,-70.939,-48.939003,-20.939003,-10.939003,-16.939003,-55.939003,-76.939,-88.939,-84.939,-84.939,-90.939,-89.939,-88.939,-90.939,-90.939,-90.939,-90.939,-89.939,-88.939,-88.939,-89.939,-90.939,-89.939,-88.939,-86.939,-86.939,-88.939,-93.939,-96.939,-92.939,-92.939,-92.939,-82.939,-75.939,-71.939,-85.939,-96.939,-93.939,-93.939,-94.939,-95.939,-96.939,-97.939,-88.939,-80.939,49.060997,50.060997,51.060997,50.060997,48.060997,49.060997,53.060997,47.060997,13.060997,-39.939003,-94.939,-88.939,-84.939,-79.939,-81.939,-81.939,-76.939,-77.939,-80.939,-78.939,-75.939,-72.939,-57.939003,-45.939003,-41.939003,-10.939003,27.060997,19.060997,18.060997,20.060997,19.060997,19.060997,23.060997,21.060997,22.060997,60.060997,83.061,89.061,45.060997,10.060997,13.060997,10.060997,2.060997,-4.939003,-10.939003,-15.939003,-13.939003,-15.939003,-28.939003,-57.939003,-89.939,-98.939,-103.939,-103.939,-101.939,-99.939,-98.939,-97.939,-96.939,-88.939,-82.939,-78.939,-84.939,-90.939,-94.939,-97.939,-100.939,-101.939,-100.939,-98.939,-96.939,-95.939,-93.939,-96.939,-95.939,-34.939003,20.060997,68.061,59.060997,49.060997,54.060997,-2.939003,-79.939,-95.939,-101.939,-97.939,-100.939,-102.939,-1
01.939,-100.939,-99.939,-101.939,-102.939,-102.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-101.939,-99.939,-100.939,-100.939,-100.939,-98.939,-96.939,-98.939,-99.939,-98.939,-96.939,-94.939,-95.939,-94.939,-92.939,-91.939,-89.939,-84.939,-79.939,-75.939,-75.939,-70.939,-64.939,-60.939003,-56.939003,-49.939003,-47.939003,-44.939003,-39.939003,-40.939003,-41.939003,-38.939003,-34.939003,-32.939003,-27.939003,-25.939003,-27.939003,-28.939003,-27.939003,-17.939003,-15.939003,-21.939003,-32.939003,-39.939003,-35.939003,-31.939003,-26.939003,-34.939003,-39.939003,-42.939003,-50.939003,-57.939003,-57.939003,-60.939003,-63.939003,-64.939,-70.939,-83.939,-83.939,-84.939,-85.939,-90.939,-95.939,-99.939,-101.939,-102.939,-100.939,-99.939,-102.939,-103.939,-103.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-102.939,-103.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-99.939,-95.939,-97.939,-97.939,-96.939,-96.939,-96.939,-89.939,-71.939,-38.939003,-29.939003,-20.939003,-8.939003,17.060997,45.060997,34.060997,25.060997,17.060997,18.060997,19.060997,17.060997,9.060997,-0.939003,0.06099701,-4.939003,-14.939003,-16.939003,-15.939003,-8.939003,-11.939003,-19.939003,-24.939003,-28.939003,-31.939003,-32.939003,-33.939003,-35.939003,-38.939003,-41.939003,-40.939003,-41.939003,-42.939003,-42.939003,-42.939003,-42.939003,-36.939003,-29.939003,-30.939003,-28.939003,-25.939003,-23.939003,-22.939003,-20.939003,-16.939003,-9.939003,-4.939003,-6.939003,-14.939003,-1.939003,13.060997,26.060997,26.060997,19.060997,9.060997,-27.939003,-92.939,-93.939,-91.939,-92.939,-93.939,-95.939,-97.939,-100.939,-103.939,-102.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-97.939,-85.939
,-29.939003,-6.939003,-17.939003,-3.939003,-1.939003,-36.939003,-59.939003,-75.939,-91.939,-99.939,-98.939,-93.939,-91.939,-96.939,-58.939003,-4.939003,29.060997,40.060997,26.060997,1.060997,-30.939003,-77.939,-80.939,-67.939,-69.939,-69.939,-68.939,-66.939,-63.939003,-55.939003,-50.939003,-48.939003,-47.939003,-45.939003,-42.939003,-41.939003,-40.939003,-42.939003,-36.939003,-27.939003,-33.939003,-36.939003,-35.939003,-30.939003,-27.939003,-30.939003,-31.939003,-31.939003,-36.939003,-37.939003,-33.939003,-44.939003,-49.939003,-42.939003,-38.939003,-38.939003,-47.939003,-50.939003,-45.939003,-50.939003,-56.939003,-64.939,-61.939003,-54.939003,-61.939003,-64.939,-63.939003,-69.939,-62.939003,-19.939003,12.060997,37.060997,40.060997,6.060997,-61.939003,-82.939,-94.939,-95.939,-95.939,-96.939,-98.939,-100.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-100.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-99.939,-97.939,-99.939,-100.939,-100.939,-99.939,-99.939,-99.939,-91.939,-75.939,-51.939003,-33.939003,-32.939003,-46.939003,-61.939003,-25.939003,4.060997,30.060997,7.060997,-15.939003,-27.939003,-28.939003,-24.939003,-22.939003,-19.939003,-16.939003,-13.939003,-9.939003,-4.939003,6.060997,19.060997,32.060997,35.060997,26.060997,34.060997,40.060997,39.060997,44.060997,48.060997,29.060997,19.060997,18.060997,21.060997,21.060997,14.060997,13.060997,16.060997,22.060997,24.060997,22.060997,24.060997,25.060997,24.060997,24.060997,25.060997,26.060997,28.060997,30.060997,30.060997,30.060997,32.060997,33.060997,34.060997,37.060997,37.060997,32.060997,33.060997,34.060997,36.060997,35.060997,33.060997,36.060997,35.060997,30.060997,31.060997,32.060997,36.060997,40.060997,43.060997,45.060997,44.060997,39.060997,35.060997,33.060997,35.060997,35.060997,34.060997,27.060997,22.060997,21.060997,24.060997,27.060997,25.060997,26.060997,27.060997,26.060997,28.060997,33.060997,28
.060997,23.060997,22.060997,22.060997,24.060997,19.060997,17.060997,17.060997,16.060997,15.060997,12.060997,11.060997,11.060997,8.060997,6.060997,3.060997,6.060997,7.060997,2.060997,-0.939003,-1.939003,-3.939003,-5.939003,-6.939003,-7.939003,-3.939003,12.060997,25.060997,35.060997,25.060997,16.060997,9.060997,9.060997,1.060997,-28.939003,-34.939003,-31.939003,-43.939003,-60.939003,-82.939,-74.939,-65.939,-59.939003,-58.939003,-59.939003,-56.939003,-51.939003,-45.939003,-51.939003,-55.939003,-50.939003,-46.939003,-45.939003,-44.939003,-44.939003,-44.939003,-43.939003,-44.939003,-45.939003,-46.939003,-46.939003,-46.939003,-47.939003,-49.939003,-52.939003,-55.939003,-55.939003,-58.939003,-61.939003,-63.939003,-64.939,-64.939,-66.939,-69.939,-71.939,-75.939,-78.939,-72.939,-62.939003,-48.939003,-48.939003,-50.939003,-52.939003,-51.939003,-48.939003,-53.939003,-52.939003,-45.939003,-12.939003,9.060997,-4.939003,-20.939003,-32.939003,-10.939003,0.06099701,0.06099701,7.060997,6.060997,-17.939003,-19.939003,-12.939003,4.060997,15.060997,21.060997,13.060997,6.060997,8.060997,6.060997,2.060997,-5.939003,-3.939003,10.060997,1.060997,-4.939003,-3.939003,2.060997,6.060997,-22.939003,-42.939003,-56.939003,-51.939003,-46.939003,-44.939003,-41.939003,-38.939003,-42.939003,-46.939003,-50.939003,-48.939003,-44.939003,-38.939003,-31.939003,-24.939003,-26.939003,-28.939003,-31.939003,-33.939003,-33.939003,-30.939003,-42.939003,-56.939003,-4.939003,29.060997,47.060997,31.060997,21.060997,35.060997,23.060997,3.060997,7.060997,17.060997,32.060997,35.060997,33.060997,28.060997,24.060997,20.060997,-9.939003,-21.939003,-13.939003,15.060997,37.060997,31.060997,30.060997,32.060997,32.060997,30.060997,26.060997,35.060997,28.060997,-21.939003,-65.939,-101.939,-94.939,-92.939,-95.939,-94.939,-93.939,-93.939,-90.939,-86.939,-89.939,-92.939,-95.939,-90.939,-88.939,-92.939,-79.939,-60.939003,-77.939,-87.939,-91.939,-94.939,-94.939,-90.939,-88.939,-87.939,-84.939,-83.939,-84.939,-81.939,-80.939,-78.9
39,-78.939,-77.939,-79.939,-81.939,-86.939,-83.939,-71.939,-37.939003,-25.939003,-24.939003,-38.939003,-55.939003,-75.939,-74.939,-71.939,-73.939,-71.939,-68.939,-73.939,-73.939,-68.939,-65.939,-65.939,-67.939,-69.939,-71.939,-67.939,-64.939,-62.939003,-61.939003,-62.939003,-65.939,-64.939,-62.939003,-61.939003,-63.939003,-68.939,-66.939,-65.939,-65.939,-65.939,-64.939,-65.939,-65.939,-63.939003,-61.939003,-59.939003,-64.939,-67.939,-68.939,-60.939003,-53.939003,-49.939003,-42.939003,-40.939003,-55.939003,-63.939003,-70.939,-66.939,-67.939,-73.939,-73.939,-72.939,-74.939,-74.939,-73.939,-74.939,-75.939,-78.939,-79.939,-80.939,-79.939,-79.939,-79.939,-78.939,-79.939,-81.939,-85.939,-89.939,-87.939,-88.939,-88.939,-77.939,-72.939,-73.939,-86.939,-95.939,-91.939,-92.939,-95.939,-98.939,-100.939,-101.939,-90.939,-80.939,39.060997,37.060997,34.060997,27.060997,21.060997,16.060997,14.060997,6.060997,-14.939003,-37.939003,-58.939003,-53.939003,-51.939003,-50.939003,-52.939003,-51.939003,-45.939003,-44.939003,-46.939003,-43.939003,-44.939003,-49.939003,-47.939003,-46.939003,-44.939003,-17.939003,16.060997,20.060997,23.060997,21.060997,18.060997,18.060997,24.060997,26.060997,28.060997,63.060997,71.061,50.060997,30.060997,15.060997,12.060997,5.060997,-2.939003,-8.939003,-12.939003,-14.939003,-13.939003,-17.939003,-31.939003,-63.939003,-98.939,-100.939,-101.939,-101.939,-101.939,-100.939,-97.939,-94.939,-92.939,-93.939,-94.939,-95.939,-97.939,-98.939,-97.939,-98.939,-99.939,-100.939,-100.939,-97.939,-97.939,-96.939,-94.939,-95.939,-90.939,-40.939003,15.060997,75.061,64.061,54.060997,71.061,12.060997,-72.939,-92.939,-100.939,-96.939,-99.939,-102.939,-100.939,-99.939,-100.939,-99.939,-98.939,-97.939,-92.939,-88.939,-85.939,-82.939,-79.939,-78.939,-76.939,-74.939,-72.939,-69.939,-67.939,-63.939003,-60.939003,-59.939003,-56.939003,-53.939003,-49.939003,-47.939003,-50.939003,-51.939003,-51.939003,-48.939003,-44.939003,-39.939003,-37.939003,-38.939003,-46.939003,-43.939003,-35.93900
3,-39.939003,-42.939003,-44.939003,-47.939003,-48.939003,-44.939003,-47.939003,-52.939003,-54.939003,-56.939003,-58.939003,-60.939003,-62.939003,-57.939003,-62.939003,-69.939,-69.939,-67.939,-64.939,-5.939003,38.060997,32.060997,40.060997,50.060997,-1.939003,-44.939003,-79.939,-85.939,-88.939,-88.939,-89.939,-89.939,-90.939,-92.939,-96.939,-96.939,-96.939,-97.939,-97.939,-96.939,-100.939,-101.939,-101.939,-99.939,-98.939,-100.939,-102.939,-103.939,-102.939,-101.939,-100.939,-101.939,-102.939,-102.939,-103.939,-102.939,-103.939,-102.939,-99.939,-100.939,-100.939,-100.939,-99.939,-98.939,-97.939,-94.939,-91.939,-86.939,-82.939,-79.939,-78.939,-76.939,-71.939,-66.939,-60.939003,-58.939003,-56.939003,-55.939003,-51.939003,-47.939003,-44.939003,-39.939003,-32.939003,-30.939003,-28.939003,-23.939003,-20.939003,-19.939003,-21.939003,-23.939003,-23.939003,-24.939003,-24.939003,-23.939003,-23.939003,-22.939003,-17.939003,-15.939003,-14.939003,-15.939003,-15.939003,-7.939003,-2.939003,1.060997,4.060997,6.060997,8.060997,8.060997,9.060997,15.060997,14.060997,10.060997,14.060997,15.060997,17.060997,18.060997,18.060997,19.060997,22.060997,26.060997,29.060997,23.060997,7.060997,12.060997,21.060997,36.060997,34.060997,28.060997,34.060997,24.060997,-1.939003,-1.939003,6.060997,24.060997,23.060997,14.060997,6.060997,-28.939003,-89.939,-97.939,-99.939,-98.939,-99.939,-100.939,-101.939,-101.939,-102.939,-102.939,-102.939,-103.939,-102.939,-100.939,-102.939,-103.939,-103.939,-102.939,-102.939,-103.939,-102.939,-100.939,-100.939,-99.939,-97.939,-96.939,-95.939,-97.939,-95.939,-92.939,-89.939,-85.939,-80.939,-79.939,-79.939,-78.939,-74.939,-65.939,-24.939003,-5.939003,-9.939003,-0.939003,-4.939003,-41.939003,-66.939,-84.939,-74.939,-63.939003,-50.939003,-50.939003,-50.939003,-46.939003,-39.939003,-31.939003,-26.939003,-22.939003,-22.939003,-28.939003,-35.939003,-43.939003,-45.939003,-44.939003,-48.939003,-48.939003,-47.939003,-50.939003,-52.939003,-50.939003,-50.939003,-51.939003,-51.939
003,-52.939003,-54.939003,-57.939003,-61.939003,-65.939,-63.939003,-58.939003,-58.939003,-61.939003,-66.939,-66.939,-66.939,-70.939,-72.939,-74.939,-75.939,-66.939,-46.939003,-67.939,-82.939,-79.939,-73.939,-69.939,-80.939,-85.939,-84.939,-85.939,-87.939,-90.939,-89.939,-87.939,-88.939,-89.939,-90.939,-92.939,-70.939,7.060997,43.060997,60.060997,66.061,26.060997,-56.939003,-84.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-102.939,-101.939,-100.939,-100.939,-99.939,-95.939,-96.939,-97.939,-93.939,-91.939,-91.939,-91.939,-90.939,-89.939,-85.939,-80.939,-76.939,-74.939,-72.939,-72.939,-70.939,-66.939,-63.939003,-62.939003,-59.939003,-54.939003,-48.939003,-42.939003,-38.939003,-38.939003,-41.939003,-42.939003,-39.939003,-29.939003,-10.939003,-11.939003,-19.939003,-38.939003,-33.939003,-19.939003,-15.939003,-11.939003,-6.939003,-6.939003,-4.939003,-2.939003,8.060997,21.060997,30.060997,32.060997,28.060997,34.060997,40.060997,43.060997,44.060997,42.060997,23.060997,14.060997,17.060997,21.060997,22.060997,17.060997,16.060997,17.060997,23.060997,24.060997,23.060997,25.060997,27.060997,26.060997,25.060997,26.060997,27.060997,29.060997,31.060997,31.060997,31.060997,32.060997,33.060997,35.060997,38.060997,37.060997,33.060997,34.060997,35.060997,37.060997,36.060997,34.060997,37.060997,36.060997,33.060997,34.060997,35.060997,37.060997,42.060997,46.060997,48.060997,46.060997,41.060997,36.060997,33.060997,36.060997,37.060997,37.060997,30.060997,25.060997,23.060997,28.060997,31.060997,28.060997,28.060997,29.060997,27.060997,28.060997,32.060997,26.060997,22.060997,22.060997,23.060997,25.060997,21.060997,18.060997,17.060997,17.060997,16.060997,13.060997,12.060997,12.060997,10.060997,7.060997,3.060997,3.060997,3.060997,1.060997,1.060997,3.060997,-0.939003,-2.939003,-4.939003,-8.939003,-8.939003,-2.939003,14.060997,32.060997,27.060997,20.060997,10.060997,13.060997,8.060997,-19.939003,-33.939003,-40.939003,-40.939003,-50
.939003,-69.939,-66.939,-61.939003,-56.939003,-55.939003,-58.939003,-61.939003,-52.939003,-29.939003,-26.939003,-26.939003,-30.939003,-51.939003,-77.939,-77.939,-77.939,-78.939,-78.939,-79.939,-81.939,-83.939,-84.939,-84.939,-84.939,-84.939,-86.939,-87.939,-86.939,-87.939,-87.939,-89.939,-90.939,-90.939,-91.939,-92.939,-92.939,-94.939,-94.939,-83.939,-64.939,-38.939003,-41.939003,-46.939003,-49.939003,-48.939003,-45.939003,-54.939003,-56.939003,-52.939003,-17.939003,4.060997,-13.939003,-26.939003,-34.939003,-21.939003,-16.939003,-21.939003,-16.939003,-15.939003,-28.939003,-34.939003,-35.939003,-25.939003,-23.939003,-27.939003,-29.939003,-31.939003,-31.939003,-35.939003,-39.939003,-38.939003,-39.939003,-41.939003,-43.939003,-43.939003,-40.939003,-38.939003,-38.939003,-40.939003,-43.939003,-47.939003,-43.939003,-39.939003,-37.939003,-34.939003,-33.939003,-37.939003,-42.939003,-46.939003,-46.939003,-45.939003,-41.939003,-32.939003,-21.939003,-24.939003,-27.939003,-32.939003,-32.939003,-31.939003,-31.939003,-37.939003,-44.939003,-19.939003,-6.939003,-3.939003,-9.939003,-12.939003,-2.939003,-11.939003,-26.939003,-25.939003,-17.939003,-2.939003,0.06099701,-0.939003,-4.939003,-0.939003,3.060997,-22.939003,-30.939003,-17.939003,0.06099701,12.060997,7.060997,7.060997,9.060997,11.060997,13.060997,16.060997,26.060997,18.060997,-30.939003,-67.939,-95.939,-89.939,-88.939,-92.939,-92.939,-93.939,-95.939,-93.939,-91.939,-93.939,-96.939,-100.939,-97.939,-95.939,-97.939,-84.939,-68.939,-85.939,-96.939,-99.939,-99.939,-99.939,-99.939,-98.939,-98.939,-96.939,-96.939,-97.939,-96.939,-95.939,-95.939,-95.939,-94.939,-95.939,-96.939,-97.939,-92.939,-70.939,-13.939003,17.060997,29.060997,-14.939003,-54.939003,-90.939,-93.939,-92.939,-93.939,-92.939,-91.939,-93.939,-93.939,-91.939,-90.939,-90.939,-91.939,-92.939,-92.939,-89.939,-87.939,-86.939,-85.939,-85.939,-88.939,-87.939,-85.939,-84.939,-84.939,-87.939,-85.939,-85.939,-85.939,-84.939,-83.939,-82.939,-81.939,-79.939,-76.939,-74.939,-75.9
39,-78.939,-78.939,-63.939003,-46.939003,-29.939003,-27.939003,-34.939003,-56.939003,-68.939,-76.939,-71.939,-71.939,-74.939,-74.939,-73.939,-75.939,-73.939,-70.939,-69.939,-70.939,-75.939,-74.939,-73.939,-72.939,-72.939,-72.939,-71.939,-72.939,-73.939,-75.939,-76.939,-74.939,-74.939,-73.939,-70.939,-67.939,-64.939,-72.939,-78.939,-78.939,-76.939,-74.939,-74.939,-77.939,-82.939,-75.939,-69.939,7.060997,6.060997,4.060997,-0.939003,-5.939003,-9.939003,-11.939003,-14.939003,-24.939003,-33.939003,-41.939003,-42.939003,-44.939003,-45.939003,-45.939003,-43.939003,-40.939003,-41.939003,-44.939003,-43.939003,-45.939003,-48.939003,-45.939003,-43.939003,-43.939003,-22.939003,4.060997,19.060997,23.060997,19.060997,17.060997,18.060997,22.060997,27.060997,34.060997,60.060997,57.060997,25.060997,19.060997,14.060997,8.060997,0.06099701,-6.939003,-10.939003,-13.939003,-15.939003,-12.939003,-18.939003,-41.939003,-72.939,-102.939,-100.939,-100.939,-101.939,-102.939,-101.939,-98.939,-95.939,-93.939,-97.939,-100.939,-101.939,-102.939,-102.939,-99.939,-98.939,-98.939,-97.939,-96.939,-93.939,-91.939,-89.939,-87.939,-86.939,-80.939,-44.939003,0.06099701,54.060997,41.060997,31.060997,44.060997,1.060997,-62.939003,-74.939,-77.939,-72.939,-76.939,-79.939,-74.939,-73.939,-74.939,-72.939,-70.939,-69.939,-66.939,-63.939003,-64.939,-62.939003,-59.939003,-57.939003,-55.939003,-54.939003,-53.939003,-52.939003,-51.939003,-50.939003,-48.939003,-48.939003,-46.939003,-42.939003,-40.939003,-40.939003,-46.939003,-48.939003,-49.939003,-48.939003,-45.939003,-44.939003,-43.939003,-46.939003,-54.939003,-52.939003,-46.939003,-51.939003,-56.939003,-60.939003,-64.939,-67.939,-64.939,-67.939,-71.939,-75.939,-78.939,-81.939,-85.939,-87.939,-82.939,-85.939,-91.939,-95.939,-94.939,-90.939,3.060997,73.061,64.061,74.061,86.061,17.060997,-43.939003,-98.939,-103.939,-103.939,-103.939,-103.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-98.939,-94.939,-97.939,-97.939,-94.939,-92.939,-90.939,-89.939,
-89.939,-90.939,-87.939,-84.939,-83.939,-81.939,-79.939,-79.939,-79.939,-79.939,-79.939,-78.939,-76.939,-75.939,-73.939,-70.939,-70.939,-71.939,-71.939,-68.939,-63.939003,-62.939003,-61.939003,-59.939003,-58.939003,-56.939003,-53.939003,-49.939003,-44.939003,-43.939003,-41.939003,-41.939003,-38.939003,-36.939003,-35.939003,-33.939003,-31.939003,-30.939003,-28.939003,-22.939003,-20.939003,-19.939003,-18.939003,-16.939003,-13.939003,-14.939003,-14.939003,-13.939003,-10.939003,-5.939003,1.060997,5.060997,6.060997,-5.939003,-8.939003,13.060997,18.060997,19.060997,32.060997,38.060997,39.060997,36.060997,36.060997,43.060997,44.060997,42.060997,45.060997,47.060997,49.060997,51.060997,51.060997,51.060997,51.060997,53.060997,58.060997,48.060997,23.060997,28.060997,41.060997,62.060997,59.060997,48.060997,55.060997,40.060997,2.060997,-2.939003,2.060997,24.060997,23.060997,13.060997,5.060997,-27.939003,-85.939,-95.939,-97.939,-91.939,-91.939,-92.939,-91.939,-89.939,-87.939,-87.939,-87.939,-87.939,-86.939,-84.939,-84.939,-84.939,-83.939,-82.939,-82.939,-82.939,-81.939,-79.939,-80.939,-77.939,-72.939,-71.939,-72.939,-77.939,-73.939,-66.939,-64.939,-63.939003,-61.939003,-59.939003,-58.939003,-59.939003,-58.939003,-54.939003,-20.939003,-3.939003,-2.939003,4.060997,-0.939003,-41.939003,-64.939,-79.939,-71.939,-60.939003,-45.939003,-43.939003,-44.939003,-44.939003,-40.939003,-33.939003,-23.939003,-16.939003,-13.939003,-20.939003,-28.939003,-43.939003,-51.939003,-55.939003,-57.939003,-58.939003,-58.939003,-61.939003,-64.939,-65.939,-67.939,-68.939,-69.939,-70.939,-73.939,-77.939,-80.939,-83.939,-82.939,-79.939,-78.939,-81.939,-85.939,-87.939,-88.939,-91.939,-94.939,-96.939,-97.939,-83.939,-54.939003,-78.939,-98.939,-96.939,-89.939,-82.939,-96.939,-102.939,-101.939,-100.939,-99.939,-98.939,-97.939,-97.939,-96.939,-95.939,-94.939,-94.939,-71.939,9.060997,40.060997,51.060997,54.060997,19.060997,-52.939003,-75.939,-87.939,-85.939,-85.939,-85.939,-82.939,-81.939,-81.939,-81.939,-80.939,-80
.939,-79.939,-79.939,-80.939,-80.939,-77.939,-75.939,-72.939,-69.939,-71.939,-74.939,-69.939,-67.939,-69.939,-67.939,-65.939,-65.939,-62.939003,-57.939003,-54.939003,-53.939003,-52.939003,-53.939003,-52.939003,-49.939003,-47.939003,-46.939003,-45.939003,-43.939003,-41.939003,-41.939003,-42.939003,-42.939003,-41.939003,-39.939003,-42.939003,-35.939003,-16.939003,-20.939003,-28.939003,-38.939003,-32.939003,-19.939003,-15.939003,-10.939003,-4.939003,-3.939003,-2.939003,1.060997,12.060997,25.060997,30.060997,31.060997,31.060997,34.060997,40.060997,47.060997,43.060997,34.060997,20.060997,14.060997,18.060997,21.060997,23.060997,20.060997,20.060997,20.060997,23.060997,24.060997,24.060997,27.060997,28.060997,27.060997,27.060997,27.060997,29.060997,30.060997,31.060997,31.060997,31.060997,32.060997,33.060997,35.060997,36.060997,36.060997,34.060997,36.060997,37.060997,37.060997,36.060997,34.060997,37.060997,37.060997,35.060997,36.060997,38.060997,39.060997,43.060997,47.060997,48.060997,47.060997,42.060997,37.060997,34.060997,35.060997,37.060997,38.060997,32.060997,28.060997,24.060997,30.060997,33.060997,29.060997,29.060997,30.060997,28.060997,28.060997,30.060997,25.060997,22.060997,24.060997,24.060997,24.060997,22.060997,20.060997,18.060997,17.060997,16.060997,14.060997,13.060997,12.060997,11.060997,8.060997,4.060997,3.060997,3.060997,1.060997,3.060997,5.060997,0.06099701,-1.939003,-2.939003,-6.939003,-9.939003,-11.939003,4.060997,24.060997,28.060997,26.060997,17.060997,15.060997,9.060997,-7.939003,-26.939003,-43.939003,-39.939003,-44.939003,-57.939003,-65.939,-70.939,-69.939,-69.939,-71.939,-76.939,-52.939003,-2.939003,4.060997,0.06099701,-17.939003,-55.939003,-96.939,-95.939,-96.939,-97.939,-98.939,-98.939,-98.939,-98.939,-100.939,-99.939,-97.939,-95.939,-96.939,-95.939,-94.939,-93.939,-91.939,-90.939,-89.939,-89.939,-90.939,-90.939,-88.939,-87.939,-85.939,-76.939,-61.939003,-39.939003,-41.939003,-46.939003,-51.939003,-49.939003,-46.939003,-52.939003,-56.939003,-56.939003,-3
4.939003,-19.939003,-31.939003,-38.939003,-41.939003,-36.939003,-36.939003,-40.939003,-36.939003,-34.939003,-38.939003,-42.939003,-45.939003,-40.939003,-40.939003,-47.939003,-47.939003,-46.939003,-45.939003,-47.939003,-47.939003,-46.939003,-46.939003,-50.939003,-49.939003,-47.939003,-41.939003,-39.939003,-38.939003,-42.939003,-46.939003,-51.939003,-49.939003,-45.939003,-42.939003,-39.939003,-37.939003,-39.939003,-42.939003,-44.939003,-45.939003,-45.939003,-42.939003,-32.939003,-21.939003,-24.939003,-27.939003,-30.939003,-30.939003,-30.939003,-30.939003,-34.939003,-38.939003,-27.939003,-24.939003,-29.939003,-31.939003,-32.939003,-28.939003,-34.939003,-42.939003,-43.939003,-38.939003,-27.939003,-25.939003,-26.939003,-29.939003,-24.939003,-17.939003,-34.939003,-38.939003,-29.939003,-23.939003,-19.939003,-22.939003,-21.939003,-19.939003,-17.939003,-14.939003,-8.939003,-0.939003,-4.939003,-36.939003,-59.939003,-75.939,-74.939,-73.939,-75.939,-76.939,-78.939,-82.939,-80.939,-78.939,-80.939,-81.939,-84.939,-84.939,-84.939,-84.939,-75.939,-65.939,-78.939,-86.939,-89.939,-89.939,-89.939,-90.939,-90.939,-90.939,-89.939,-88.939,-89.939,-90.939,-90.939,-90.939,-90.939,-90.939,-91.939,-92.939,-93.939,-88.939,-68.939,-12.939003,20.060997,35.060997,-13.939003,-55.939003,-91.939,-95.939,-95.939,-95.939,-95.939,-95.939,-96.939,-97.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-96.939,-96.939,-96.939,-95.939,-95.939,-97.939,-97.939,-96.939,-96.939,-95.939,-96.939,-96.939,-96.939,-97.939,-96.939,-95.939,-94.939,-92.939,-91.939,-88.939,-86.939,-85.939,-88.939,-88.939,-59.939003,-31.939003,-3.939003,-8.939003,-27.939003,-63.939003,-80.939,-86.939,-83.939,-82.939,-84.939,-83.939,-83.939,-84.939,-82.939,-80.939,-78.939,-79.939,-82.939,-81.939,-79.939,-79.939,-79.939,-80.939,-79.939,-79.939,-80.939,-80.939,-80.939,-79.939,-78.939,-77.939,-73.939,-70.939,-67.939,-74.939,-80.939,-81.939,-79.939,-75.939,-74.939,-77.939,-82.939,-72.939,-64.939,-47.939003,-43.939003,-39.939003,-35.939003,
-30.939003,-28.939003,-22.939003,-17.939003,-16.939003,-28.939003,-44.939003,-56.939003,-63.939003,-64.939,-60.939003,-58.939003,-63.939003,-67.939,-73.939,-78.939,-78.939,-70.939,-51.939003,-37.939003,-36.939003,-24.939003,-8.939003,14.060997,20.060997,12.060997,16.060997,19.060997,17.060997,26.060997,40.060997,49.060997,41.060997,14.060997,10.060997,8.060997,3.060997,-3.939003,-8.939003,-9.939003,-12.939003,-18.939003,-11.939003,-17.939003,-56.939003,-83.939,-100.939,-98.939,-99.939,-102.939,-103.939,-103.939,-101.939,-100.939,-100.939,-100.939,-99.939,-98.939,-101.939,-103.939,-101.939,-99.939,-96.939,-92.939,-89.939,-85.939,-77.939,-72.939,-72.939,-69.939,-64.939,-46.939003,-23.939003,4.060997,-7.939003,-20.939003,-26.939003,-37.939003,-49.939003,-39.939003,-32.939003,-27.939003,-31.939003,-33.939003,-24.939003,-22.939003,-21.939003,-19.939003,-19.939003,-20.939003,-21.939003,-26.939003,-37.939003,-40.939003,-40.939003,-38.939003,-37.939003,-38.939003,-44.939003,-49.939003,-54.939003,-57.939003,-60.939003,-66.939,-69.939,-67.939,-70.939,-75.939,-81.939,-84.939,-85.939,-89.939,-92.939,-97.939,-98.939,-98.939,-99.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-99.939,-100.939,-100.939,-100.939,-101.939,-100.939,-101.939,-102.939,-101.939,-97.939,-92.939,-95.939,-97.939,-99.939,-5.939003,65.061,59.060997,69.061,81.061,23.060997,-35.939003,-97.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-99.939,-93.939,-87.939,-92.939,-89.939,-79.939,-77.939,-75.939,-70.939,-66.939,-63.939003,-56.939003,-51.939003,-50.939003,-43.939003,-36.939003,-32.939003,-31.939003,-30.939003,-31.939003,-32.939003,-34.939003,-28.939003,-21.939003,-15.939003,-16.939003,-20.939003,-25.939003,-24.939003,-19.939003,-28.939003,-36.939003,-41.939003,-42.939003,-41.939003,-46.939003,-48.939003,-49.939003,-50.939003,-52.939003,-54.939003,-57.939003,-61.939003,-62.939003,-54.939003,-36.939003,-28.939003,-19.939003,-5.939003,18.060997,44.06099
7,43.060997,45.060997,48.060997,50.060997,50.060997,47.060997,46.060997,49.060997,57.060997,57.060997,47.060997,15.060997,6.060997,54.060997,52.060997,34.060997,59.060997,69.061,61.060997,51.060997,45.060997,50.060997,52.060997,54.060997,52.060997,52.060997,53.060997,56.060997,57.060997,52.060997,50.060997,50.060997,56.060997,47.060997,22.060997,24.060997,36.060997,57.060997,56.060997,49.060997,56.060997,39.060997,-0.939003,-5.939003,1.060997,25.060997,25.060997,14.060997,8.060997,-23.939003,-80.939,-88.939,-84.939,-72.939,-68.939,-69.939,-66.939,-63.939003,-59.939003,-57.939003,-56.939003,-56.939003,-56.939003,-56.939003,-50.939003,-47.939003,-44.939003,-43.939003,-42.939003,-41.939003,-41.939003,-40.939003,-41.939003,-37.939003,-27.939003,-27.939003,-32.939003,-43.939003,-37.939003,-24.939003,-29.939003,-34.939003,-42.939003,-40.939003,-38.939003,-43.939003,-49.939003,-52.939003,-17.939003,0.06099701,2.060997,12.060997,8.060997,-34.939003,-51.939003,-60.939003,-82.939,-89.939,-82.939,-74.939,-73.939,-91.939,-61.939003,-10.939003,39.060997,59.060997,51.060997,26.060997,-11.939003,-78.939,-98.939,-99.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-99.939,-98.939,-94.939,-91.939,-93.939,-94.939,-92.939,-94.939,-94.939,-94.939,-96.939,-99.939,-101.939,-87.939,-56.939003,-78.939,-97.939,-94.939,-84.939,-75.939,-93.939,-101.939,-96.939,-95.939,-92.939,-88.939,-86.939,-85.939,-84.939,-80.939,-75.939,-77.939,-64.939,-15.939003,3.060997,8.060997,4.060997,-15.939003,-48.939003,-55.939003,-55.939003,-50.939003,-49.939003,-48.939003,-43.939003,-40.939003,-38.939003,-36.939003,-35.939003,-34.939003,-34.939003,-34.939003,-39.939003,-39.939003,-34.939003,-28.939003,-23.939003,-24.939003,-28.939003,-33.939003,-31.939003,-32.939003,-34.939003,-30.939003,-28.939003,-29.939003,-31.939003,-32.939003,-34.939003,-36.939003,-38.939003,-42.939003,-46.939003,-50.939003,-52.939003,-51.939003,-56.939003,-58.939003,-56.939003,-49.93
9003,-45.939003,-43.939003,-47.939003,-51.939003,-34.939003,-13.939003,12.060997,-18.939003,-42.939003,-29.939003,-25.939003,-24.939003,-22.939003,-17.939003,-9.939003,-6.939003,-1.939003,7.060997,19.060997,30.060997,31.060997,32.060997,34.060997,35.060997,40.060997,51.060997,41.060997,23.060997,19.060997,18.060997,20.060997,21.060997,23.060997,24.060997,25.060997,24.060997,23.060997,24.060997,27.060997,28.060997,28.060997,27.060997,28.060997,29.060997,30.060997,31.060997,31.060997,30.060997,30.060997,31.060997,33.060997,35.060997,33.060997,33.060997,34.060997,37.060997,38.060997,36.060997,33.060997,31.060997,34.060997,36.060997,37.060997,38.060997,39.060997,40.060997,43.060997,45.060997,47.060997,46.060997,43.060997,39.060997,35.060997,34.060997,35.060997,37.060997,33.060997,29.060997,24.060997,29.060997,33.060997,29.060997,28.060997,29.060997,28.060997,28.060997,28.060997,26.060997,25.060997,27.060997,26.060997,23.060997,23.060997,22.060997,19.060997,18.060997,17.060997,14.060997,12.060997,11.060997,11.060997,9.060997,5.060997,6.060997,6.060997,4.060997,4.060997,3.060997,-0.939003,-2.939003,-1.939003,-3.939003,-6.939003,-12.939003,-3.939003,10.060997,27.060997,33.060997,28.060997,15.060997,5.060997,6.060997,-14.939003,-40.939003,-39.939003,-41.939003,-47.939003,-71.939,-93.939,-98.939,-100.939,-100.939,-99.939,-53.939003,35.060997,41.060997,27.060997,-12.939003,-59.939003,-102.939,-100.939,-101.939,-102.939,-101.939,-99.939,-95.939,-93.939,-93.939,-90.939,-86.939,-81.939,-81.939,-80.939,-79.939,-76.939,-72.939,-65.939,-62.939003,-62.939003,-64.939,-64.939,-59.939003,-55.939003,-52.939003,-52.939003,-52.939003,-50.939003,-50.939003,-51.939003,-56.939003,-56.939003,-53.939003,-49.939003,-51.939003,-58.939003,-61.939003,-62.939003,-58.939003,-55.939003,-52.939003,-56.939003,-58.939003,-56.939003,-53.939003,-49.939003,-45.939003,-44.939003,-42.939003,-39.939003,-37.939003,-36.939003,-38.939003,-38.939003,-33.939003,-28.939003,-22.939003,-27.939003,-25.939003,-15.93900
3,-15.939003,-14.939003,-6.939003,0.06099701,4.060997,-27.939003,-52.939003,-70.939,-68.939,-64.939,-59.939003,-54.939003,-50.939003,-47.939003,-46.939003,-43.939003,-44.939003,-44.939003,-40.939003,-32.939003,-22.939003,-25.939003,-26.939003,-26.939003,-28.939003,-29.939003,-28.939003,-32.939003,-37.939003,-26.939003,-24.939003,-29.939003,-34.939003,-38.939003,-42.939003,-43.939003,-43.939003,-45.939003,-44.939003,-42.939003,-42.939003,-44.939003,-45.939003,-45.939003,-42.939003,-44.939003,-46.939003,-48.939003,-54.939003,-59.939003,-59.939003,-58.939003,-56.939003,-54.939003,-52.939003,-49.939003,-45.939003,-41.939003,-40.939003,-41.939003,-42.939003,-47.939003,-48.939003,-44.939003,-45.939003,-48.939003,-53.939003,-51.939003,-47.939003,-49.939003,-48.939003,-45.939003,-51.939003,-55.939003,-51.939003,-50.939003,-51.939003,-55.939003,-59.939003,-60.939003,-62.939003,-64.939,-63.939003,-63.939003,-64.939,-62.939003,-61.939003,-61.939003,-63.939003,-64.939,-65.939,-65.939,-65.939,-67.939,-70.939,-73.939,-73.939,-63.939003,-35.939003,-16.939003,-7.939003,-35.939003,-59.939003,-78.939,-82.939,-82.939,-78.939,-78.939,-79.939,-81.939,-84.939,-86.939,-86.939,-86.939,-88.939,-88.939,-88.939,-89.939,-91.939,-91.939,-90.939,-90.939,-92.939,-94.939,-95.939,-95.939,-95.939,-96.939,-98.939,-99.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-97.939,-95.939,-94.939,-97.939,-97.939,-50.939003,-7.939003,29.060997,14.060997,-18.939003,-78.939,-98.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-100.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-99.939,-85.939,-79.939,-82.939,-92.939,-100.939,-101.939,-100.939,-98.939,-99.939,-101.939,-101.939,-79.939,-63.939003,11.060997,13.060997,16.060997,18.060997,19.060997,19.060997,22.060997,23.060997,14.060997,-17.939003,-57.939003,-74.939,-78.939,-72.939,-73.939,-73.939,-73.939,-72.939,-70.939,-67.939,-58.939003,-4
4.939003,-32.939003,-25.939003,-30.939003,-24.939003,-12.939003,11.060997,20.060997,13.060997,18.060997,23.060997,23.060997,36.060997,52.060997,47.060997,34.060997,11.060997,7.060997,3.060997,-3.939003,-7.939003,-11.939003,-12.939003,-14.939003,-16.939003,-17.939003,-27.939003,-65.939,-88.939,-102.939,-99.939,-94.939,-87.939,-81.939,-77.939,-75.939,-72.939,-68.939,-65.939,-62.939003,-59.939003,-59.939003,-60.939003,-60.939003,-59.939003,-57.939003,-54.939003,-50.939003,-47.939003,-46.939003,-43.939003,-40.939003,-39.939003,-37.939003,-31.939003,-22.939003,-10.939003,-13.939003,-16.939003,-17.939003,-26.939003,-40.939003,-47.939003,-50.939003,-50.939003,-55.939003,-58.939003,-53.939003,-54.939003,-57.939003,-57.939003,-58.939003,-59.939003,-63.939003,-68.939,-72.939,-73.939,-72.939,-72.939,-72.939,-73.939,-74.939,-77.939,-80.939,-82.939,-83.939,-85.939,-87.939,-86.939,-89.939,-92.939,-95.939,-96.939,-97.939,-99.939,-100.939,-101.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-98.939,-91.939,-88.939,-89.939,-93.939,-19.939003,38.060997,38.060997,42.060997,42.060997,4.060997,-36.939003,-78.939,-75.939,-70.939,-69.939,-65.939,-59.939003,-56.939003,-55.939003,-56.939003,-58.939003,-59.939003,-57.939003,-51.939003,-45.939003,-50.939003,-51.939003,-48.939003,-51.939003,-52.939003,-47.939003,-46.939003,-46.939003,-46.939003,-47.939003,-49.939003,-47.939003,-44.939003,-44.939003,-47.939003,-51.939003,-56.939003,-59.939003,-62.939003,-57.939003,-54.939003,-54.939003,-56.939003,-58.939003,-63.939003,-65.939,-63.939003,-68.939,-72.939,-74.939,-75.939,-75.939,-78.939,-78.939,-78.939,-77.939,-78.939,-81.939,-83.939,-83.939,-85.939,-71.939,-38.939003,-26.939003,-15.939003,-3.939003,23.060997,52.060997,53.060997,53.060997,52.060997,52.060997,53.060997,55.060997,53.060997,50.060997,57.060997,59.060997,53.060997,16.060997,2.060997,50.060997,49.060997,32.0609
97,56.060997,65.061,57.060997,44.060997,37.060997,45.060997,44.060997,40.060997,38.060997,36.060997,33.060997,32.060997,29.060997,21.060997,18.060997,18.060997,19.060997,13.060997,-1.939003,-3.939003,-1.939003,5.060997,4.060997,1.060997,1.060997,-4.939003,-14.939003,-4.939003,9.060997,26.060997,26.060997,17.060997,13.060997,-15.939003,-67.939,-66.939,-59.939003,-54.939003,-55.939003,-57.939003,-52.939003,-50.939003,-51.939003,-55.939003,-59.939003,-60.939003,-59.939003,-57.939003,-59.939003,-60.939003,-61.939003,-62.939003,-63.939003,-63.939003,-64.939,-65.939,-67.939,-66.939,-62.939003,-65.939,-70.939,-74.939,-71.939,-66.939,-68.939,-70.939,-73.939,-72.939,-71.939,-73.939,-77.939,-77.939,-25.939003,2.060997,8.060997,15.060997,10.060997,-24.939003,-39.939003,-48.939003,-82.939,-98.939,-96.939,-91.939,-90.939,-100.939,-64.939,-8.939003,50.060997,75.061,67.061,38.060997,-4.939003,-80.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-97.939,-93.939,-92.939,-89.939,-87.939,-87.939,-87.939,-86.939,-83.939,-80.939,-79.939,-76.939,-72.939,-72.939,-72.939,-72.939,-72.939,-71.939,-69.939,-61.939003,-47.939003,-57.939003,-64.939,-60.939003,-55.939003,-50.939003,-59.939003,-60.939003,-55.939003,-52.939003,-49.939003,-48.939003,-47.939003,-46.939003,-46.939003,-44.939003,-38.939003,-41.939003,-40.939003,-31.939003,-25.939003,-21.939003,-22.939003,-29.939003,-40.939003,-45.939003,-46.939003,-44.939003,-44.939003,-45.939003,-47.939003,-48.939003,-49.939003,-47.939003,-46.939003,-48.939003,-50.939003,-53.939003,-59.939003,-62.939003,-61.939003,-58.939003,-56.939003,-57.939003,-61.939003,-64.939,-65.939,-67.939,-71.939,-69.939,-68.939,-69.939,-70.939,-71.939,-72.939,-73.939,-74.939,-75.939,-76.939,-79.939,-80.939,-79.939,-82.939,-80.939,-72.939,-47.939003,-31.939003,-36.939003,-45.939003,-53.939003,-26.939003,-9.939003,-1.939003,-22.939003,-37.939003,-27.939003,-21.939003,-17.939003,-15.939003,-11.939003,-7.939003,-7.939003,-2.939003,14.060997,25.060997,3
4.060997,33.060997,33.060997,32.060997,38.060997,44.060997,46.060997,33.060997,16.060997,19.060997,21.060997,22.060997,23.060997,24.060997,25.060997,26.060997,26.060997,25.060997,26.060997,28.060997,28.060997,27.060997,28.060997,29.060997,30.060997,31.060997,32.060997,33.060997,32.060997,31.060997,32.060997,34.060997,35.060997,34.060997,34.060997,34.060997,36.060997,37.060997,34.060997,32.060997,32.060997,33.060997,35.060997,37.060997,36.060997,37.060997,41.060997,42.060997,43.060997,46.060997,45.060997,42.060997,40.060997,37.060997,34.060997,35.060997,37.060997,32.060997,28.060997,25.060997,29.060997,32.060997,28.060997,27.060997,27.060997,28.060997,28.060997,27.060997,25.060997,24.060997,26.060997,25.060997,23.060997,24.060997,23.060997,20.060997,19.060997,18.060997,15.060997,13.060997,11.060997,11.060997,9.060997,5.060997,6.060997,7.060997,6.060997,5.060997,2.060997,0.06099701,-1.939003,-2.939003,-2.939003,-4.939003,-9.939003,-7.939003,-1.939003,20.060997,30.060997,28.060997,17.060997,8.060997,10.060997,-7.939003,-33.939003,-38.939003,-40.939003,-39.939003,-57.939003,-75.939,-89.939,-86.939,-78.939,-76.939,-54.939003,-10.939003,-9.939003,-17.939003,-34.939003,-53.939003,-70.939,-66.939,-66.939,-67.939,-66.939,-65.939,-64.939,-63.939003,-61.939003,-60.939003,-59.939003,-59.939003,-59.939003,-60.939003,-61.939003,-60.939003,-59.939003,-57.939003,-55.939003,-54.939003,-58.939003,-60.939003,-60.939003,-60.939003,-61.939003,-58.939003,-56.939003,-54.939003,-52.939003,-51.939003,-54.939003,-53.939003,-50.939003,-51.939003,-49.939003,-45.939003,-30.939003,-22.939003,-32.939003,-33.939003,-30.939003,-24.939003,-22.939003,-23.939003,-18.939003,-19.939003,-30.939003,-29.939003,-23.939003,-10.939003,-3.939003,-1.939003,-4.939003,-7.939003,-6.939003,-4.939003,-3.939003,-10.939003,-9.939003,-1.939003,-6.939003,-10.939003,-5.939003,-1.939003,-1.939003,-30.939003,-51.939003,-62.939003,-60.939003,-57.939003,-55.939003,-50.939003,-46.939003,-43.939003,-43.939003,-45.939003,-46.93
9003,-45.939003,-38.939003,-32.939003,-26.939003,-31.939003,-30.939003,-26.939003,-25.939003,-25.939003,-27.939003,-34.939003,-39.939003,2.060997,19.060997,12.060997,6.060997,0.06099701,-2.939003,-10.939003,-17.939003,-12.939003,-7.939003,-1.939003,-3.939003,-5.939003,-3.939003,-5.939003,-10.939003,-31.939003,-34.939003,-18.939003,-10.939003,-7.939003,-13.939003,-12.939003,-10.939003,-8.939003,-8.939003,-6.939003,-2.939003,-6.939003,-37.939003,-58.939003,-72.939,-68.939,-66.939,-67.939,-68.939,-69.939,-69.939,-68.939,-67.939,-65.939,-65.939,-65.939,-68.939,-69.939,-65.939,-59.939003,-52.939003,-64.939,-69.939,-68.939,-69.939,-69.939,-69.939,-69.939,-69.939,-69.939,-69.939,-68.939,-67.939,-67.939,-67.939,-67.939,-67.939,-68.939,-70.939,-72.939,-69.939,-60.939003,-39.939003,-27.939003,-21.939003,-38.939003,-54.939003,-68.939,-70.939,-70.939,-65.939,-64.939,-65.939,-64.939,-65.939,-66.939,-66.939,-66.939,-67.939,-67.939,-67.939,-69.939,-71.939,-74.939,-71.939,-69.939,-71.939,-72.939,-72.939,-72.939,-72.939,-73.939,-75.939,-75.939,-74.939,-73.939,-73.939,-75.939,-76.939,-75.939,-74.939,-73.939,-75.939,-77.939,-78.939,-54.939003,-35.939003,-20.939003,-22.939003,-35.939003,-68.939,-78.939,-78.939,-81.939,-83.939,-86.939,-85.939,-84.939,-87.939,-87.939,-87.939,-88.939,-88.939,-90.939,-88.939,-88.939,-92.939,-93.939,-92.939,-92.939,-93.939,-96.939,-95.939,-95.939,-96.939,-97.939,-96.939,-81.939,-75.939,-79.939,-91.939,-99.939,-99.939,-100.939,-101.939,-101.939,-102.939,-102.939,-81.939,-65.939,58.060997,59.060997,61.060997,60.060997,58.060997,55.060997,57.060997,54.060997,40.060997,-8.939003,-69.939,-87.939,-89.939,-77.939,-80.939,-81.939,-73.939,-64.939,-55.939003,-47.939003,-35.939003,-21.939003,-16.939003,-16.939003,-23.939003,-23.939003,-18.939003,7.060997,19.060997,18.060997,22.060997,26.060997,26.060997,36.060997,47.060997,37.060997,24.060997,9.060997,4.060997,-0.939003,-8.939003,-11.939003,-13.939003,-15.939003,-16.939003,-15.939003,-23.939003,-39.939003,-73.939,-91.
939,-99.939,-97.939,-88.939,-72.939,-61.939003,-54.939003,-54.939003,-50.939003,-43.939003,-41.939003,-39.939003,-38.939003,-37.939003,-37.939003,-39.939003,-40.939003,-39.939003,-37.939003,-35.939003,-34.939003,-36.939003,-37.939003,-33.939003,-33.939003,-34.939003,-31.939003,-20.939003,-0.939003,-0.939003,2.060997,10.060997,-5.939003,-31.939003,-60.939003,-75.939,-77.939,-82.939,-85.939,-83.939,-86.939,-92.939,-93.939,-95.939,-95.939,-99.939,-102.939,-101.939,-100.939,-100.939,-100.939,-101.939,-100.939,-99.939,-99.939,-101.939,-100.939,-100.939,-99.939,-99.939,-100.939,-102.939,-103.939,-102.939,-101.939,-101.939,-101.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-96.939,-95.939,-95.939,-94.939,-93.939,-93.939,-93.939,-91.939,-89.939,-89.939,-89.939,-89.939,-88.939,-87.939,-87.939,-83.939,-76.939,-69.939,-68.939,-73.939,-27.939003,8.060997,11.060997,11.060997,5.060997,-14.939003,-35.939003,-58.939003,-52.939003,-45.939003,-45.939003,-40.939003,-32.939003,-28.939003,-28.939003,-31.939003,-35.939003,-38.939003,-37.939003,-32.939003,-29.939003,-34.939003,-38.939003,-39.939003,-45.939003,-49.939003,-44.939003,-45.939003,-48.939003,-52.939003,-57.939003,-61.939003,-62.939003,-63.939003,-66.939,-71.939,-77.939,-83.939,-88.939,-90.939,-88.939,-87.939,-91.939,-93.939,-94.939,-98.939,-100.939,-100.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-99.939,-97.939,-97.939,-100.939,-99.939,-98.939,-99.939,-81.939,-41.939003,-26.939003,-14.939003,-3.939003,20.060997,47.060997,49.060997,47.060997,42.060997,39.060997,39.060997,43.060997,40.060997,35.060997,40.060997,42.060997,39.060997,6.060997,-6.939003,31.060997,30.060997,16.060997,33.060997,40.060997,34.060997,23.060997,17.060997,25.060997,23.060997,17.060997,14.060997,12.060997,8.060997,5.060997,1.060997,-5.939003,-7.939003,-7.939003,-8.939003,-11.939003,-19.939003,-22.939003,-24.939003,-23.939003,-24.939003,-25.939003,-27.939003,-26.939003,-20.939003,-2.939003,14.060997,27.060997,27.060997,20.060997
,16.060997,-9.939003,-59.939003,-57.939003,-52.939003,-53.939003,-57.939003,-60.939003,-54.939003,-53.939003,-57.939003,-64.939,-70.939,-72.939,-70.939,-69.939,-75.939,-79.939,-82.939,-84.939,-87.939,-88.939,-88.939,-90.939,-92.939,-94.939,-94.939,-99.939,-103.939,-101.939,-101.939,-101.939,-101.939,-100.939,-99.939,-98.939,-99.939,-99.939,-101.939,-96.939,-31.939003,5.060997,12.060997,16.060997,12.060997,-12.939003,-26.939003,-40.939003,-81.939,-101.939,-98.939,-97.939,-95.939,-96.939,-62.939003,-13.939003,41.060997,65.061,58.060997,30.060997,-8.939003,-72.939,-90.939,-89.939,-89.939,-88.939,-87.939,-87.939,-86.939,-85.939,-79.939,-73.939,-72.939,-69.939,-66.939,-66.939,-66.939,-66.939,-64.939,-61.939003,-59.939003,-55.939003,-50.939003,-50.939003,-50.939003,-53.939003,-51.939003,-48.939003,-45.939003,-42.939003,-40.939003,-44.939003,-44.939003,-39.939003,-37.939003,-37.939003,-39.939003,-38.939003,-34.939003,-31.939003,-29.939003,-31.939003,-31.939003,-29.939003,-31.939003,-30.939003,-25.939003,-29.939003,-33.939003,-33.939003,-27.939003,-21.939003,-18.939003,-23.939003,-37.939003,-46.939003,-53.939003,-53.939003,-54.939003,-57.939003,-62.939003,-66.939,-68.939,-66.939,-66.939,-70.939,-74.939,-77.939,-83.939,-87.939,-88.939,-88.939,-89.939,-90.939,-92.939,-93.939,-95.939,-98.939,-101.939,-102.939,-102.939,-103.939,-102.939,-101.939,-101.939,-101.939,-101.939,-99.939,-98.939,-98.939,-97.939,-95.939,-96.939,-90.939,-78.939,-45.939003,-22.939003,-31.939003,-41.939003,-47.939003,-18.939003,-8.939003,-19.939003,-26.939003,-29.939003,-24.939003,-17.939003,-12.939003,-8.939003,-5.939003,-3.939003,-5.939003,-0.939003,19.060997,29.060997,35.060997,35.060997,33.060997,31.060997,41.060997,47.060997,39.060997,25.060997,13.060997,19.060997,23.060997,22.060997,23.060997,24.060997,25.060997,26.060997,27.060997,26.060997,27.060997,29.060997,28.060997,28.060997,28.060997,29.060997,30.060997,32.060997,33.060997,34.060997,33.060997,32.060997,33.060997,34.060997,35.060997,36.060997,3
6.060997,34.060997,36.060997,36.060997,33.060997,33.060997,34.060997,33.060997,34.060997,37.060997,36.060997,36.060997,41.060997,42.060997,42.060997,46.060997,46.060997,42.060997,41.060997,39.060997,35.060997,36.060997,38.060997,32.060997,28.060997,25.060997,30.060997,32.060997,29.060997,28.060997,26.060997,28.060997,28.060997,26.060997,25.060997,24.060997,25.060997,25.060997,24.060997,24.060997,23.060997,20.060997,19.060997,18.060997,16.060997,14.060997,12.060997,11.060997,9.060997,5.060997,6.060997,6.060997,7.060997,5.060997,1.060997,1.060997,0.06099701,-2.939003,-2.939003,-3.939003,-5.939003,-9.939003,-10.939003,12.060997,25.060997,30.060997,20.060997,12.060997,13.060997,-0.939003,-20.939003,-35.939003,-39.939003,-34.939003,-44.939003,-58.939003,-76.939,-72.939,-59.939003,-59.939003,-54.939003,-46.939003,-45.939003,-46.939003,-48.939003,-51.939003,-53.939003,-48.939003,-47.939003,-49.939003,-49.939003,-49.939003,-51.939003,-50.939003,-49.939003,-49.939003,-51.939003,-53.939003,-54.939003,-56.939003,-58.939003,-60.939003,-61.939003,-62.939003,-62.939003,-61.939003,-64.939,-68.939,-70.939,-74.939,-76.939,-67.939,-59.939003,-51.939003,-49.939003,-49.939003,-51.939003,-49.939003,-47.939003,-53.939003,-50.939003,-37.939003,-5.939003,11.060997,-10.939003,-18.939003,-16.939003,-4.939003,-0.939003,-2.939003,1.060997,-2.939003,-22.939003,-23.939003,-16.939003,-0.939003,8.060997,10.060997,7.060997,2.060997,-0.939003,-2.939003,-5.939003,-12.939003,-12.939003,-7.939003,-15.939003,-20.939003,-17.939003,-17.939003,-18.939003,-37.939003,-49.939003,-52.939003,-50.939003,-47.939003,-48.939003,-44.939003,-41.939003,-39.939003,-40.939003,-45.939003,-48.939003,-46.939003,-35.939003,-31.939003,-29.939003,-33.939003,-32.939003,-26.939003,-24.939003,-24.939003,-27.939003,-35.939003,-38.939003,17.060997,40.060997,33.060997,26.060997,21.060997,18.060997,7.060997,-4.939003,5.060997,15.060997,23.060997,20.060997,17.060997,22.060997,20.060997,12.060997,-24.939003,-27.939003,3.060997,21.0609
97,30.060997,20.060997,20.060997,23.060997,24.060997,25.060997,25.060997,30.060997,18.060997,-37.939003,-73.939,-97.939,-86.939,-83.939,-87.939,-89.939,-89.939,-85.939,-86.939,-86.939,-82.939,-83.939,-86.939,-86.939,-85.939,-81.939,-69.939,-57.939003,-76.939,-83.939,-81.939,-81.939,-81.939,-81.939,-81.939,-80.939,-82.939,-82.939,-80.939,-79.939,-78.939,-77.939,-77.939,-76.939,-77.939,-78.939,-79.939,-71.939,-56.939003,-29.939003,-17.939003,-15.939003,-37.939003,-55.939003,-68.939,-69.939,-69.939,-65.939,-64.939,-64.939,-62.939003,-61.939003,-62.939003,-62.939003,-62.939003,-61.939003,-61.939003,-61.939003,-63.939003,-65.939,-68.939,-65.939,-62.939003,-64.939,-64.939,-63.939003,-62.939003,-62.939003,-63.939003,-64.939,-64.939,-60.939003,-59.939003,-58.939003,-62.939003,-63.939003,-62.939003,-61.939003,-62.939003,-65.939,-66.939,-67.939,-58.939003,-54.939003,-54.939003,-48.939003,-47.939003,-62.939003,-64.939,-61.939003,-65.939,-69.939,-74.939,-72.939,-70.939,-74.939,-74.939,-73.939,-74.939,-75.939,-78.939,-76.939,-75.939,-79.939,-80.939,-79.939,-79.939,-81.939,-85.939,-84.939,-83.939,-85.939,-86.939,-86.939,-73.939,-69.939,-73.939,-83.939,-90.939,-90.939,-92.939,-94.939,-93.939,-93.939,-94.939,-78.939,-65.939,52.060997,52.060997,52.060997,51.060997,51.060997,50.060997,55.060997,56.060997,49.060997,-3.939003,-71.939,-89.939,-91.939,-79.939,-77.939,-69.939,-51.939003,-36.939003,-21.939003,-15.939003,-13.939003,-16.939003,-18.939003,-18.939003,-14.939003,-20.939003,-27.939003,-1.939003,18.060997,27.060997,27.060997,25.060997,17.060997,11.060997,6.060997,2.060997,2.060997,8.060997,3.060997,-3.939003,-9.939003,-13.939003,-15.939003,-19.939003,-19.939003,-16.939003,-31.939003,-53.939003,-85.939,-91.939,-87.939,-86.939,-76.939,-57.939003,-50.939003,-48.939003,-54.939003,-51.939003,-47.939003,-53.939003,-61.939003,-69.939,-70.939,-72.939,-75.939,-78.939,-81.939,-82.939,-84.939,-83.939,-85.939,-85.939,-86.939,-86.939,-85.939,-69.939,-22.939003,54.060997,48.060997,43.060997,61
.060997,29.060997,-23.939003,-72.939,-96.939,-97.939,-97.939,-97.939,-95.939,-97.939,-100.939,-101.939,-100.939,-99.939,-99.939,-100.939,-101.939,-102.939,-102.939,-102.939,-100.939,-98.939,-100.939,-102.939,-102.939,-101.939,-99.939,-101.939,-101.939,-98.939,-100.939,-101.939,-97.939,-95.939,-93.939,-95.939,-95.939,-94.939,-91.939,-87.939,-79.939,-75.939,-73.939,-70.939,-67.939,-65.939,-61.939003,-58.939003,-58.939003,-53.939003,-46.939003,-44.939003,-42.939003,-41.939003,-37.939003,-34.939003,-34.939003,-31.939003,-27.939003,-24.939003,-23.939003,-25.939003,-22.939003,-21.939003,-24.939003,-23.939003,-23.939003,-23.939003,-30.939003,-43.939003,-47.939003,-50.939003,-54.939003,-57.939003,-60.939003,-58.939003,-60.939003,-66.939,-72.939,-76.939,-75.939,-77.939,-81.939,-86.939,-88.939,-87.939,-90.939,-90.939,-89.939,-90.939,-90.939,-91.939,-92.939,-93.939,-94.939,-94.939,-95.939,-96.939,-97.939,-98.939,-99.939,-100.939,-99.939,-99.939,-100.939,-101.939,-100.939,-100.939,-100.939,-98.939,-100.939,-103.939,-103.939,-101.939,-100.939,-101.939,-99.939,-95.939,-93.939,-92.939,-93.939,-94.939,-96.939,-94.939,-77.939,-46.939003,-32.939003,-19.939003,-4.939003,12.060997,26.060997,24.060997,21.060997,16.060997,10.060997,3.060997,-2.939003,-4.939003,-3.939003,-1.939003,-4.939003,-11.939003,-21.939003,-25.939003,-13.939003,-18.939003,-26.939003,-25.939003,-24.939003,-24.939003,-26.939003,-27.939003,-26.939003,-23.939003,-21.939003,-24.939003,-26.939003,-24.939003,-21.939003,-18.939003,-14.939003,-12.939003,-9.939003,-5.939003,-6.939003,-14.939003,-12.939003,-5.939003,8.060997,11.060997,10.060997,20.060997,15.060997,-5.939003,-1.939003,9.060997,29.060997,28.060997,20.060997,18.060997,-11.939003,-68.939,-83.939,-91.939,-91.939,-92.939,-93.939,-92.939,-92.939,-92.939,-94.939,-95.939,-96.939,-95.939,-95.939,-96.939,-97.939,-98.939,-98.939,-98.939,-99.939,-100.939,-100.939,-101.939,-101.939,-100.939,-100.939,-101.939,-102.939,-103.939,-102.939,-99.939,-98.939,-97.939,-99.939,-100.93
9,-101.939,-102.939,-96.939,-29.939003,6.060997,14.060997,16.060997,13.060997,2.060997,-15.939003,-36.939003,-78.939,-94.939,-83.939,-78.939,-75.939,-73.939,-54.939003,-28.939003,0.06099701,10.060997,1.060997,-15.939003,-31.939003,-45.939003,-46.939003,-44.939003,-41.939003,-38.939003,-34.939003,-35.939003,-35.939003,-36.939003,-34.939003,-32.939003,-33.939003,-35.939003,-35.939003,-35.939003,-34.939003,-33.939003,-32.939003,-31.939003,-34.939003,-36.939003,-35.939003,-36.939003,-39.939003,-44.939003,-46.939003,-47.939003,-46.939003,-43.939003,-39.939003,-52.939003,-61.939003,-56.939003,-53.939003,-53.939003,-64.939,-70.939,-69.939,-71.939,-74.939,-79.939,-77.939,-73.939,-76.939,-75.939,-72.939,-81.939,-68.939,-2.939003,30.060997,50.060997,54.060997,22.060997,-45.939003,-75.939,-91.939,-91.939,-92.939,-92.939,-94.939,-95.939,-95.939,-95.939,-95.939,-95.939,-96.939,-97.939,-98.939,-99.939,-100.939,-98.939,-98.939,-99.939,-100.939,-101.939,-100.939,-100.939,-99.939,-99.939,-100.939,-101.939,-100.939,-96.939,-96.939,-95.939,-94.939,-90.939,-87.939,-85.939,-82.939,-78.939,-74.939,-68.939,-62.939003,-43.939003,-31.939003,-36.939003,-35.939003,-29.939003,-8.939003,-10.939003,-34.939003,-26.939003,-18.939003,-15.939003,-15.939003,-14.939003,-10.939003,-4.939003,1.060997,3.060997,7.060997,24.060997,29.060997,30.060997,33.060997,35.060997,35.060997,46.060997,49.060997,30.060997,19.060997,14.060997,20.060997,21.060997,18.060997,21.060997,23.060997,24.060997,26.060997,27.060997,26.060997,26.060997,27.060997,30.060997,31.060997,30.060997,30.060997,30.060997,30.060997,31.060997,32.060997,33.060997,33.060997,35.060997,35.060997,35.060997,37.060997,37.060997,35.060997,37.060997,38.060997,37.060997,38.060997,39.060997,36.060997,36.060997,39.060997,37.060997,38.060997,41.060997,43.060997,44.060997,47.060997,47.060997,43.060997,42.060997,41.060997,38.060997,38.060997,39.060997,33.060997,29.060997,27.060997,31.060997,33.060997,32.060997,29.060997,28.060997,30.060997,29.060997,24.06099
7,24.060997,25.060997,24.060997,24.060997,25.060997,25.060997,23.060997,20.060997,19.060997,19.060997,17.060997,16.060997,14.060997,12.060997,9.060997,6.060997,4.060997,5.060997,6.060997,4.060997,2.060997,4.060997,2.060997,-3.939003,-4.939003,-4.939003,-4.939003,-8.939003,-13.939003,2.060997,18.060997,35.060997,26.060997,17.060997,12.060997,8.060997,2.060997,-27.939003,-39.939003,-34.939003,-42.939003,-53.939003,-66.939,-67.939,-64.939,-64.939,-54.939003,-34.939003,-24.939003,-21.939003,-35.939003,-59.939003,-84.939,-78.939,-77.939,-80.939,-81.939,-83.939,-87.939,-88.939,-88.939,-90.939,-91.939,-92.939,-92.939,-92.939,-93.939,-93.939,-93.939,-93.939,-93.939,-93.939,-94.939,-95.939,-95.939,-96.939,-95.939,-76.939,-56.939003,-36.939003,-40.939003,-45.939003,-47.939003,-47.939003,-46.939003,-58.939003,-57.939003,-42.939003,-7.939003,10.060997,-14.939003,-26.939003,-31.939003,-28.939003,-25.939003,-23.939003,-25.939003,-29.939003,-37.939003,-41.939003,-43.939003,-40.939003,-38.939003,-37.939003,-40.939003,-43.939003,-50.939003,-53.939003,-55.939003,-56.939003,-57.939003,-59.939003,-58.939003,-58.939003,-56.939003,-52.939003,-48.939003,-49.939003,-48.939003,-47.939003,-42.939003,-39.939003,-40.939003,-39.939003,-37.939003,-36.939003,-38.939003,-43.939003,-46.939003,-45.939003,-34.939003,-29.939003,-27.939003,-28.939003,-27.939003,-25.939003,-28.939003,-30.939003,-29.939003,-31.939003,-33.939003,-11.939003,-1.939003,-5.939003,-9.939003,-13.939003,-13.939003,-20.939003,-29.939003,-19.939003,-9.939003,2.060997,-4.939003,-8.939003,-6.939003,-4.939003,-6.939003,-35.939003,-35.939003,-6.939003,5.060997,11.060997,5.060997,5.060997,8.060997,9.060997,10.060997,12.060997,15.060997,3.060997,-42.939003,-72.939,-92.939,-88.939,-87.939,-89.939,-90.939,-93.939,-94.939,-91.939,-87.939,-89.939,-91.939,-95.939,-96.939,-94.939,-86.939,-75.939,-66.939,-88.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-97.939,-97.939,-97.939,-97.939,-97.939,-97.939,-97.9
39,-97.939,-82.939,-51.939003,12.060997,33.060997,29.060997,-32.939003,-73.939,-93.939,-94.939,-94.939,-94.939,-94.939,-94.939,-93.939,-93.939,-93.939,-93.939,-93.939,-93.939,-92.939,-91.939,-92.939,-93.939,-95.939,-94.939,-94.939,-93.939,-92.939,-92.939,-89.939,-87.939,-85.939,-88.939,-89.939,-84.939,-82.939,-80.939,-82.939,-82.939,-80.939,-80.939,-81.939,-85.939,-84.939,-79.939,-59.939003,-43.939003,-29.939003,-31.939003,-42.939003,-71.939,-75.939,-69.939,-70.939,-72.939,-75.939,-73.939,-72.939,-74.939,-71.939,-68.939,-69.939,-70.939,-71.939,-67.939,-65.939,-66.939,-65.939,-62.939003,-62.939003,-65.939,-70.939,-68.939,-65.939,-64.939,-64.939,-64.939,-60.939003,-60.939003,-64.939,-66.939,-68.939,-68.939,-69.939,-69.939,-66.939,-66.939,-69.939,-64.939,-60.939003,44.060997,42.060997,40.060997,36.060997,34.060997,31.060997,32.060997,29.060997,21.060997,-14.939003,-59.939003,-71.939,-72.939,-63.939003,-58.939003,-49.939003,-34.939003,-23.939003,-15.939003,-13.939003,-13.939003,-15.939003,-18.939003,-19.939003,-13.939003,-19.939003,-26.939003,-6.939003,11.060997,27.060997,21.060997,12.060997,1.060997,-5.939003,-10.939003,-13.939003,-12.939003,-7.939003,-8.939003,-10.939003,-13.939003,-15.939003,-17.939003,-18.939003,-19.939003,-20.939003,-43.939003,-68.939,-92.939,-96.939,-92.939,-92.939,-85.939,-72.939,-70.939,-70.939,-74.939,-73.939,-72.939,-77.939,-82.939,-89.939,-90.939,-91.939,-93.939,-96.939,-97.939,-98.939,-99.939,-99.939,-99.939,-99.939,-98.939,-99.939,-101.939,-87.939,-32.939003,63.060997,59.060997,55.060997,72.061,42.060997,-8.939003,-72.939,-103.939,-102.939,-102.939,-102.939,-100.939,-101.939,-103.939,-100.939,-98.939,-97.939,-95.939,-93.939,-92.939,-91.939,-89.939,-88.939,-85.939,-82.939,-82.939,-82.939,-80.939,-77.939,-74.939,-73.939,-73.939,-70.939,-70.939,-70.939,-67.939,-64.939,-62.939003,-63.939003,-62.939003,-60.939003,-57.939003,-55.939003,-52.939003,-52.939003,-54.939003,-51.939003,-49.939003,-48.939003,-47.939003,-47.939003,-51.939003,-48.939003,-4
5.939003,-45.939003,-46.939003,-47.939003,-44.939003,-43.939003,-46.939003,-46.939003,-45.939003,-42.939003,-44.939003,-50.939003,-25.939003,-2.939003,11.060997,18.060997,19.060997,11.060997,-16.939003,-62.939003,-71.939,-74.939,-74.939,-77.939,-81.939,-81.939,-84.939,-88.939,-91.939,-94.939,-94.939,-96.939,-100.939,-102.939,-103.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-100.939,-100.939,-99.939,-97.939,-94.939,-93.939,-92.939,-90.939,-89.939,-86.939,-85.939,-84.939,-82.939,-81.939,-81.939,-79.939,-76.939,-71.939,-69.939,-67.939,-63.939003,-63.939003,-64.939,-63.939003,-54.939003,-38.939003,-33.939003,-26.939003,-14.939003,-7.939003,-1.939003,-2.939003,-3.939003,-7.939003,-8.939003,-11.939003,-13.939003,-12.939003,-9.939003,-7.939003,-6.939003,-6.939003,-16.939003,-19.939003,-2.939003,-4.939003,-12.939003,-3.939003,1.060997,2.060997,-2.939003,-4.939003,3.060997,6.060997,7.060997,5.060997,4.060997,5.060997,9.060997,12.060997,13.060997,13.060997,14.060997,22.060997,20.060997,9.060997,4.060997,9.060997,32.060997,36.060997,31.060997,45.060997,36.060997,5.060997,0.06099701,5.060997,27.060997,28.060997,20.060997,17.060997,-12.939003,-70.939,-91.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-100.939,-98.939,-98.939,-96.939,-93.939,-91.939,-89.939,-88.939,-87.939,-85.939,-85.939,-85.939,-83.939,-78.939,-74.939,-72.939,-71.939,-71.939,-70.939,-72.939,-70.939,-24.939003,2.060997,11.060997,23.060997,26.060997,8.060997,-12.939003,-34.939003,-64.939,-68.939,-46.939003,-45.939003,-46.939003,-44.939003,-39.939003,-32.939003,-20.939003,-15.939003,-18.939003,-25.939003,-32.939003,-39.939003,-39.939003,-37.939003,-39.939003,-39.939003,-38.939003,-38.939003,-39.939003,-42.939003,-41.939003,-39.939003,-43.939003,-48.939003,-53.939003,-53.939003,-53.939003,-51.939003,-52.939003,-53.939003,-58.939003
,-60.939003,-60.939003,-60.939003,-63.939003,-68.939,-70.939,-71.939,-71.939,-63.939003,-47.939003,-67.939,-82.939,-79.939,-74.939,-71.939,-83.939,-89.939,-90.939,-92.939,-94.939,-96.939,-95.939,-93.939,-95.939,-94.939,-91.939,-98.939,-79.939,0.06099701,44.060997,70.061,70.061,31.060997,-46.939003,-82.939,-102.939,-97.939,-97.939,-98.939,-94.939,-92.939,-90.939,-89.939,-87.939,-88.939,-88.939,-87.939,-86.939,-84.939,-82.939,-79.939,-78.939,-77.939,-75.939,-73.939,-73.939,-73.939,-71.939,-69.939,-68.939,-68.939,-66.939,-63.939003,-62.939003,-62.939003,-61.939003,-58.939003,-54.939003,-53.939003,-52.939003,-52.939003,-49.939003,-46.939003,-43.939003,-39.939003,-37.939003,-42.939003,-28.939003,-9.939003,-13.939003,-22.939003,-36.939003,-23.939003,-13.939003,-12.939003,-10.939003,-9.939003,-7.939003,-3.939003,0.06099701,7.060997,16.060997,28.060997,31.060997,30.060997,34.060997,37.060997,40.060997,45.060997,43.060997,25.060997,18.060997,17.060997,20.060997,21.060997,17.060997,20.060997,22.060997,25.060997,27.060997,28.060997,27.060997,27.060997,28.060997,31.060997,33.060997,32.060997,31.060997,31.060997,30.060997,30.060997,31.060997,33.060997,33.060997,35.060997,35.060997,35.060997,37.060997,37.060997,35.060997,36.060997,37.060997,38.060997,39.060997,39.060997,37.060997,36.060997,38.060997,38.060997,38.060997,41.060997,43.060997,44.060997,48.060997,48.060997,44.060997,43.060997,42.060997,40.060997,40.060997,41.060997,34.060997,29.060997,26.060997,31.060997,34.060997,33.060997,31.060997,30.060997,30.060997,29.060997,25.060997,26.060997,28.060997,26.060997,26.060997,26.060997,25.060997,23.060997,21.060997,20.060997,20.060997,18.060997,17.060997,16.060997,13.060997,10.060997,8.060997,7.060997,7.060997,6.060997,4.060997,1.060997,2.060997,0.06099701,-3.939003,-3.939003,-3.939003,-3.939003,-6.939003,-10.939003,-5.939003,8.060997,30.060997,29.060997,23.060997,12.060997,11.060997,11.060997,-21.939003,-36.939003,-36.939003,-39.939003,-46.939003,-57.939003,-69.939,-81.939,-79.939
,-50.939003,3.060997,13.060997,7.060997,-26.939003,-63.939003,-97.939,-94.939,-93.939,-96.939,-96.939,-97.939,-98.939,-98.939,-97.939,-98.939,-98.939,-96.939,-95.939,-93.939,-92.939,-92.939,-92.939,-89.939,-87.939,-86.939,-88.939,-88.939,-85.939,-84.939,-80.939,-65.939,-52.939003,-40.939003,-43.939003,-47.939003,-49.939003,-48.939003,-47.939003,-52.939003,-52.939003,-43.939003,-29.939003,-21.939003,-29.939003,-38.939003,-45.939003,-47.939003,-47.939003,-44.939003,-43.939003,-43.939003,-45.939003,-46.939003,-47.939003,-47.939003,-45.939003,-43.939003,-41.939003,-41.939003,-46.939003,-47.939003,-47.939003,-44.939003,-42.939003,-42.939003,-43.939003,-42.939003,-34.939003,-30.939003,-28.939003,-44.939003,-52.939003,-53.939003,-48.939003,-44.939003,-45.939003,-44.939003,-43.939003,-42.939003,-41.939003,-43.939003,-46.939003,-44.939003,-34.939003,-30.939003,-28.939003,-27.939003,-27.939003,-29.939003,-32.939003,-33.939003,-31.939003,-33.939003,-35.939003,-22.939003,-18.939003,-25.939003,-27.939003,-29.939003,-29.939003,-32.939003,-37.939003,-35.939003,-30.939003,-20.939003,-25.939003,-29.939003,-28.939003,-26.939003,-25.939003,-42.939003,-43.939003,-29.939003,-27.939003,-25.939003,-23.939003,-25.939003,-26.939003,-26.939003,-25.939003,-21.939003,-18.939003,-23.939003,-45.939003,-61.939003,-71.939,-70.939,-70.939,-71.939,-73.939,-75.939,-76.939,-74.939,-72.939,-74.939,-76.939,-79.939,-79.939,-78.939,-74.939,-69.939,-64.939,-78.939,-85.939,-87.939,-85.939,-85.939,-84.939,-84.939,-83.939,-85.939,-85.939,-85.939,-85.939,-86.939,-86.939,-86.939,-86.939,-89.939,-91.939,-92.939,-79.939,-53.939003,0.06099701,17.060997,14.060997,-39.939003,-74.939,-91.939,-95.939,-96.939,-95.939,-93.939,-93.939,-95.939,-97.939,-97.939,-98.939,-99.939,-100.939,-99.939,-96.939,-96.939,-97.939,-100.939,-101.939,-102.939,-101.939,-101.939,-101.939,-99.939,-97.939,-94.939,-97.939,-99.939,-96.939,-95.939,-94.939,-95.939,-95.939,-93.939,-93.939,-94.939,-96.939,-95.939,-89.939,-49.939003,-20.939003,-0.939
003,-13.939003,-38.939003,-80.939,-89.939,-85.939,-86.939,-86.939,-88.939,-87.939,-86.939,-87.939,-84.939,-82.939,-83.939,-83.939,-84.939,-82.939,-80.939,-80.939,-79.939,-77.939,-76.939,-78.939,-82.939,-80.939,-78.939,-78.939,-75.939,-71.939,-64.939,-66.939,-76.939,-76.939,-75.939,-76.939,-76.939,-75.939,-76.939,-76.939,-76.939,-64.939,-56.939003,31.060997,27.060997,21.060997,15.060997,10.060997,6.060997,1.060997,-5.939003,-17.939003,-29.939003,-41.939003,-46.939003,-46.939003,-41.939003,-34.939003,-26.939003,-20.939003,-19.939003,-20.939003,-21.939003,-20.939003,-15.939003,-18.939003,-20.939003,-15.939003,-18.939003,-22.939003,-12.939003,1.060997,18.060997,9.060997,-3.939003,-16.939003,-18.939003,-16.939003,-20.939003,-24.939003,-28.939003,-23.939003,-19.939003,-18.939003,-20.939003,-22.939003,-17.939003,-20.939003,-29.939003,-58.939003,-84.939,-98.939,-103.939,-103.939,-103.939,-101.939,-96.939,-99.939,-101.939,-99.939,-100.939,-103.939,-102.939,-102.939,-103.939,-102.939,-102.939,-103.939,-103.939,-103.939,-100.939,-99.939,-100.939,-98.939,-96.939,-91.939,-94.939,-98.939,-92.939,-43.939003,50.060997,52.060997,51.060997,61.060997,42.060997,5.060997,-64.939,-98.939,-95.939,-97.939,-98.939,-97.939,-97.939,-97.939,-91.939,-88.939,-89.939,-86.939,-81.939,-77.939,-73.939,-69.939,-66.939,-64.939,-61.939003,-59.939003,-55.939003,-51.939003,-46.939003,-42.939003,-40.939003,-38.939003,-36.939003,-35.939003,-34.939003,-32.939003,-31.939003,-29.939003,-28.939003,-26.939003,-22.939003,-22.939003,-23.939003,-28.939003,-35.939003,-43.939003,-41.939003,-41.939003,-42.939003,-45.939003,-50.939003,-57.939003,-61.939003,-63.939003,-67.939,-72.939,-75.939,-75.939,-77.939,-82.939,-86.939,-88.939,-83.939,-87.939,-100.939,-31.939003,33.060997,72.061,85.061,83.061,60.060997,1.060997,-93.939,-103.939,-102.939,-97.939,-96.939,-97.939,-99.939,-101.939,-103.939,-101.939,-101.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.93
9,-99.939,-99.939,-99.939,-98.939,-99.939,-99.939,-95.939,-93.939,-92.939,-89.939,-85.939,-79.939,-78.939,-78.939,-75.939,-72.939,-70.939,-64.939,-58.939003,-55.939003,-56.939003,-57.939003,-53.939003,-49.939003,-44.939003,-41.939003,-39.939003,-31.939003,-29.939003,-28.939003,-29.939003,-29.939003,-27.939003,-31.939003,-32.939003,-27.939003,-25.939003,-25.939003,-21.939003,-20.939003,-21.939003,-17.939003,-11.939003,-4.939003,-1.939003,1.060997,5.060997,12.060997,22.060997,1.060997,-5.939003,31.060997,32.060997,22.060997,48.060997,60.060997,59.060997,48.060997,43.060997,60.060997,62.060997,60.060997,58.060997,57.060997,56.060997,59.060997,59.060997,52.060997,49.060997,48.060997,56.060997,53.060997,38.060997,23.060997,21.060997,51.060997,53.060997,43.060997,57.060997,46.060997,12.060997,1.060997,1.060997,24.060997,26.060997,20.060997,16.060997,-13.939003,-68.939,-89.939,-100.939,-98.939,-98.939,-99.939,-99.939,-99.939,-98.939,-98.939,-98.939,-97.939,-97.939,-97.939,-96.939,-96.939,-97.939,-92.939,-88.939,-86.939,-81.939,-76.939,-71.939,-67.939,-66.939,-65.939,-62.939003,-61.939003,-59.939003,-56.939003,-49.939003,-44.939003,-39.939003,-36.939003,-33.939003,-32.939003,-35.939003,-38.939003,-18.939003,-3.939003,6.060997,32.060997,42.060997,11.060997,-12.939003,-32.939003,-50.939003,-42.939003,-9.939003,-14.939003,-21.939003,-20.939003,-25.939003,-29.939003,-25.939003,-20.939003,-16.939003,-16.939003,-23.939003,-44.939003,-50.939003,-50.939003,-57.939003,-63.939003,-65.939,-65.939,-66.939,-71.939,-70.939,-67.939,-71.939,-79.939,-87.939,-90.939,-91.939,-88.939,-90.939,-93.939,-96.939,-98.939,-97.939,-96.939,-97.939,-101.939,-102.939,-102.939,-101.939,-87.939,-58.939003,-81.939,-101.939,-100.939,-93.939,-85.939,-96.939,-100.939,-100.939,-100.939,-99.939,-95.939,-96.939,-98.939,-97.939,-96.939,-93.939,-95.939,-77.939,-8.939003,33.060997,61.060997,54.060997,19.060997,-42.939003,-76.939,-95.939,-85.939,-84.939,-85.939,-78.939,-73.939,-70.939,-67.939,-64.939,-66.939,-65.939,
-65.939,-62.939003,-57.939003,-52.939003,-50.939003,-47.939003,-44.939003,-40.939003,-36.939003,-37.939003,-37.939003,-38.939003,-34.939003,-31.939003,-29.939003,-27.939003,-25.939003,-26.939003,-27.939003,-26.939003,-24.939003,-21.939003,-22.939003,-25.939003,-29.939003,-29.939003,-29.939003,-29.939003,-36.939003,-43.939003,-47.939003,-21.939003,8.060997,-23.939003,-37.939003,-31.939003,-20.939003,-12.939003,-10.939003,-6.939003,-1.939003,-4.939003,-4.939003,-2.939003,10.060997,23.060997,31.060997,33.060997,32.060997,35.060997,40.060997,46.060997,42.060997,35.060997,22.060997,19.060997,20.060997,21.060997,20.060997,18.060997,19.060997,21.060997,26.060997,28.060997,29.060997,28.060997,29.060997,30.060997,33.060997,35.060997,34.060997,33.060997,32.060997,30.060997,29.060997,30.060997,32.060997,34.060997,35.060997,36.060997,36.060997,35.060997,35.060997,34.060997,34.060997,35.060997,37.060997,38.060997,37.060997,36.060997,36.060997,36.060997,38.060997,39.060997,41.060997,43.060997,45.060997,48.060997,48.060997,45.060997,44.060997,43.060997,42.060997,42.060997,42.060997,35.060997,29.060997,24.060997,29.060997,33.060997,33.060997,33.060997,32.060997,31.060997,30.060997,27.060997,29.060997,31.060997,29.060997,28.060997,26.060997,24.060997,23.060997,22.060997,21.060997,21.060997,19.060997,18.060997,17.060997,14.060997,11.060997,10.060997,11.060997,11.060997,7.060997,3.060997,0.06099701,-1.939003,-2.939003,-2.939003,-1.939003,-0.939003,-2.939003,-4.939003,-5.939003,-11.939003,-1.939003,22.060997,30.060997,29.060997,14.060997,13.060997,13.060997,-14.939003,-31.939003,-38.939003,-37.939003,-39.939003,-49.939003,-73.939,-99.939,-95.939,-47.939003,43.060997,47.060997,28.060997,-21.939003,-64.939,-99.939,-99.939,-99.939,-99.939,-99.939,-98.939,-95.939,-92.939,-89.939,-89.939,-87.939,-84.939,-81.939,-78.939,-76.939,-77.939,-75.939,-69.939,-66.939,-64.939,-67.939,-67.939,-62.939003,-59.939003,-55.939003,-50.939003,-49.939003,-52.939003,-51.939003,-50.939003,-52.939003,-51.939003,
-49.939003,-45.939003,-43.939003,-41.939003,-54.939003,-58.939003,-46.939003,-47.939003,-54.939003,-59.939003,-60.939003,-57.939003,-51.939003,-46.939003,-47.939003,-44.939003,-40.939003,-37.939003,-33.939003,-28.939003,-22.939003,-18.939003,-19.939003,-18.939003,-17.939003,-12.939003,-6.939003,-1.939003,-6.939003,-7.939003,6.060997,11.060997,8.060997,-33.939003,-57.939003,-62.939003,-58.939003,-54.939003,-52.939003,-51.939003,-50.939003,-49.939003,-47.939003,-44.939003,-45.939003,-43.939003,-36.939003,-32.939003,-30.939003,-27.939003,-28.939003,-35.939003,-36.939003,-35.939003,-33.939003,-37.939003,-40.939003,-22.939003,-18.939003,-31.939003,-31.939003,-30.939003,-32.939003,-33.939003,-36.939003,-40.939003,-41.939003,-35.939003,-37.939003,-39.939003,-41.939003,-41.939003,-40.939003,-44.939003,-48.939003,-49.939003,-56.939003,-57.939003,-49.939003,-52.939003,-57.939003,-59.939003,-58.939003,-53.939003,-49.939003,-47.939003,-48.939003,-49.939003,-48.939003,-50.939003,-51.939003,-52.939003,-53.939003,-53.939003,-52.939003,-52.939003,-54.939003,-55.939003,-56.939003,-57.939003,-57.939003,-57.939003,-59.939003,-59.939003,-59.939003,-60.939003,-63.939003,-67.939,-64.939,-62.939003,-61.939003,-60.939003,-59.939003,-62.939003,-63.939003,-61.939003,-63.939003,-64.939,-65.939,-65.939,-65.939,-70.939,-75.939,-78.939,-70.939,-57.939003,-30.939003,-22.939003,-23.939003,-51.939003,-69.939,-78.939,-84.939,-87.939,-83.939,-80.939,-79.939,-83.939,-87.939,-88.939,-90.939,-93.939,-95.939,-93.939,-89.939,-88.939,-90.939,-95.939,-96.939,-97.939,-97.939,-98.939,-100.939,-99.939,-96.939,-93.939,-97.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-95.939,-37.939003,1.060997,23.060997,-0.939003,-35.939003,-86.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-101.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-101.939,-100.939,-101.939,-102.939,-102.939,-102.939,-103.939,-98.939,-9
0.939,-75.939,-77.939,-96.939,-96.939,-94.939,-94.939,-93.939,-92.939,-99.939,-100.939,-95.939,-70.939,-51.939003,-22.939003,-23.939003,-26.939003,-27.939003,-28.939003,-28.939003,-21.939003,-16.939003,-14.939003,-22.939003,-32.939003,-33.939003,-33.939003,-34.939003,-26.939003,-19.939003,-18.939003,-19.939003,-21.939003,-19.939003,-18.939003,-18.939003,-16.939003,-16.939003,-18.939003,-21.939003,-25.939003,-29.939003,-28.939003,-22.939003,-12.939003,-7.939003,-12.939003,-15.939003,-18.939003,-23.939003,-29.939003,-34.939003,-35.939003,-35.939003,-36.939003,-37.939003,-39.939003,-40.939003,-49.939003,-62.939003,-81.939,-97.939,-101.939,-103.939,-103.939,-103.939,-102.939,-101.939,-102.939,-102.939,-102.939,-101.939,-102.939,-102.939,-103.939,-103.939,-102.939,-102.939,-103.939,-103.939,-102.939,-98.939,-94.939,-93.939,-91.939,-89.939,-84.939,-82.939,-81.939,-73.939,-41.939003,17.060997,16.060997,11.060997,10.060997,-2.939003,-23.939003,-49.939003,-58.939003,-50.939003,-49.939003,-49.939003,-46.939003,-46.939003,-46.939003,-40.939003,-39.939003,-44.939003,-39.939003,-35.939003,-37.939003,-38.939003,-39.939003,-41.939003,-44.939003,-46.939003,-47.939003,-48.939003,-48.939003,-49.939003,-51.939003,-56.939003,-59.939003,-62.939003,-62.939003,-63.939003,-65.939,-67.939,-68.939,-71.939,-73.939,-75.939,-75.939,-75.939,-77.939,-80.939,-83.939,-82.939,-82.939,-83.939,-84.939,-85.939,-88.939,-89.939,-90.939,-90.939,-92.939,-94.939,-94.939,-94.939,-96.939,-96.939,-95.939,-92.939,-92.939,-97.939,-36.939003,24.060997,69.061,80.061,76.061,72.061,16.060997,-88.939,-101.939,-101.939,-98.939,-97.939,-98.939,-97.939,-97.939,-98.939,-96.939,-94.939,-91.939,-89.939,-87.939,-85.939,-82.939,-80.939,-78.939,-76.939,-73.939,-71.939,-70.939,-62.939003,-58.939003,-58.939003,-59.939003,-58.939003,-52.939003,-54.939003,-59.939003,-55.939003,-54.939003,-57.939003,-53.939003,-50.939003,-49.939003,-54.939003,-60.939003,-57.939003,-55.939003,-56.939003,-53.939003,-51.939003,-53.939003,-57.939003,-
62.939003,-66.939,-68.939,-65.939,-64.939,-64.939,-67.939,-66.939,-63.939003,-72.939,-64.939,-39.939003,-30.939003,-22.939003,-13.939003,5.060997,27.060997,35.060997,36.060997,33.060997,35.060997,39.060997,42.060997,41.060997,37.060997,39.060997,45.060997,55.060997,14.060997,-2.939003,49.060997,48.060997,29.060997,53.060997,63.060997,60.060997,53.060997,48.060997,52.060997,55.060997,59.060997,57.060997,57.060997,57.060997,58.060997,58.060997,55.060997,54.060997,53.060997,57.060997,52.060997,39.060997,22.060997,18.060997,49.060997,48.060997,36.060997,48.060997,40.060997,11.060997,-3.939003,-4.939003,21.060997,26.060997,21.060997,11.060997,-14.939003,-57.939003,-66.939,-67.939,-63.939003,-59.939003,-56.939003,-58.939003,-54.939003,-47.939003,-47.939003,-45.939003,-40.939003,-38.939003,-38.939003,-39.939003,-39.939003,-38.939003,-35.939003,-34.939003,-33.939003,-31.939003,-29.939003,-31.939003,-32.939003,-31.939003,-32.939003,-34.939003,-36.939003,-37.939003,-37.939003,-38.939003,-39.939003,-39.939003,-40.939003,-42.939003,-47.939003,-54.939003,-58.939003,-26.939003,-3.939003,10.060997,34.060997,43.060997,16.060997,-3.939003,-23.939003,-67.939,-81.939,-66.939,-67.939,-70.939,-74.939,-56.939003,-24.939003,28.060997,50.060997,42.060997,27.060997,-2.939003,-66.939,-85.939,-85.939,-88.939,-90.939,-90.939,-90.939,-91.939,-92.939,-89.939,-84.939,-85.939,-87.939,-88.939,-89.939,-89.939,-87.939,-89.939,-91.939,-91.939,-91.939,-90.939,-89.939,-90.939,-91.939,-91.939,-90.939,-86.939,-75.939,-55.939003,-70.939,-82.939,-77.939,-71.939,-65.939,-73.939,-75.939,-71.939,-68.939,-64.939,-60.939003,-56.939003,-53.939003,-52.939003,-51.939003,-51.939003,-48.939003,-42.939003,-28.939003,-23.939003,-23.939003,-23.939003,-28.939003,-37.939003,-40.939003,-41.939003,-38.939003,-34.939003,-30.939003,-31.939003,-33.939003,-35.939003,-34.939003,-33.939003,-38.939003,-40.939003,-40.939003,-37.939003,-36.939003,-36.939003,-38.939003,-40.939003,-45.939003,-45.939003,-45.939003,-46.939003,-50.939003
,-59.939003,-60.939003,-60.939003,-60.939003,-62.939003,-64.939,-66.939,-68.939,-69.939,-69.939,-70.939,-73.939,-75.939,-76.939,-78.939,-72.939,-60.939003,-51.939003,-44.939003,-32.939003,-16.939003,-3.939003,-26.939003,-33.939003,-24.939003,-20.939003,-16.939003,-12.939003,-7.939003,-2.939003,-1.939003,0.06099701,2.060997,12.060997,22.060997,31.060997,34.060997,35.060997,40.060997,44.060997,49.060997,35.060997,23.060997,20.060997,19.060997,20.060997,18.060997,19.060997,23.060997,23.060997,23.060997,24.060997,25.060997,27.060997,29.060997,30.060997,31.060997,31.060997,31.060997,33.060997,34.060997,35.060997,33.060997,32.060997,31.060997,33.060997,34.060997,36.060997,36.060997,37.060997,37.060997,38.060997,38.060997,35.060997,34.060997,36.060997,35.060997,33.060997,36.060997,37.060997,37.060997,37.060997,37.060997,40.060997,43.060997,46.060997,47.060997,49.060997,48.060997,46.060997,44.060997,43.060997,43.060997,43.060997,37.060997,32.060997,27.060997,31.060997,34.060997,32.060997,32.060997,32.060997,31.060997,30.060997,30.060997,31.060997,32.060997,28.060997,27.060997,25.060997,23.060997,23.060997,24.060997,22.060997,21.060997,20.060997,18.060997,16.060997,13.060997,11.060997,13.060997,12.060997,12.060997,9.060997,8.060997,5.060997,1.060997,0.06099701,0.06099701,1.060997,1.060997,-2.939003,-1.939003,0.06099701,-10.939003,-6.939003,10.060997,26.060997,34.060997,23.060997,18.060997,12.060997,-5.939003,-21.939003,-36.939003,-35.939003,-36.939003,-44.939003,-61.939003,-80.939,-80.939,-56.939003,-6.939003,-9.939003,-21.939003,-38.939003,-51.939003,-61.939003,-59.939003,-58.939003,-56.939003,-57.939003,-56.939003,-52.939003,-51.939003,-50.939003,-52.939003,-51.939003,-48.939003,-48.939003,-49.939003,-49.939003,-52.939003,-54.939003,-51.939003,-51.939003,-54.939003,-57.939003,-58.939003,-57.939003,-60.939003,-62.939003,-59.939003,-54.939003,-49.939003,-50.939003,-51.939003,-50.939003,-48.939003,-46.939003,-52.939003,-49.939003,-36.939003,-19.939003,-11.939003,-26.939003,-2
8.939003,-23.939003,-19.939003,-15.939003,-12.939003,-9.939003,-11.939003,-24.939003,-22.939003,-14.939003,1.060997,10.060997,12.060997,12.060997,9.060997,1.060997,-2.939003,-4.939003,1.060997,6.060997,10.060997,-0.939003,-7.939003,0.06099701,0.06099701,-6.939003,-38.939003,-53.939003,-51.939003,-49.939003,-47.939003,-46.939003,-44.939003,-42.939003,-44.939003,-44.939003,-43.939003,-45.939003,-44.939003,-36.939003,-31.939003,-29.939003,-28.939003,-29.939003,-31.939003,-33.939003,-34.939003,-38.939003,-38.939003,-32.939003,20.060997,37.060997,17.060997,17.060997,17.060997,14.060997,5.060997,-3.939003,5.060997,12.060997,19.060997,17.060997,15.060997,13.060997,7.060997,-1.939003,-26.939003,-24.939003,5.060997,8.060997,7.060997,7.060997,7.060997,8.060997,5.060997,5.060997,9.060997,6.060997,-10.939003,-57.939003,-77.939,-83.939,-81.939,-80.939,-79.939,-80.939,-79.939,-73.939,-75.939,-81.939,-81.939,-81.939,-81.939,-84.939,-83.939,-74.939,-68.939,-65.939,-75.939,-79.939,-78.939,-78.939,-77.939,-76.939,-73.939,-71.939,-71.939,-70.939,-66.939,-68.939,-69.939,-69.939,-68.939,-66.939,-68.939,-71.939,-73.939,-65.939,-55.939003,-39.939003,-31.939003,-28.939003,-48.939003,-62.939003,-70.939,-69.939,-68.939,-68.939,-68.939,-67.939,-65.939,-66.939,-67.939,-69.939,-72.939,-73.939,-71.939,-68.939,-67.939,-68.939,-72.939,-71.939,-71.939,-70.939,-71.939,-71.939,-69.939,-67.939,-64.939,-69.939,-73.939,-72.939,-71.939,-69.939,-69.939,-71.939,-73.939,-75.939,-78.939,-76.939,-78.939,-76.939,-48.939003,-30.939003,-21.939003,-31.939003,-46.939003,-71.939,-78.939,-76.939,-78.939,-80.939,-82.939,-82.939,-82.939,-82.939,-82.939,-80.939,-82.939,-84.939,-85.939,-85.939,-86.939,-88.939,-91.939,-93.939,-90.939,-89.939,-89.939,-93.939,-97.939,-102.939,-96.939,-86.939,-75.939,-77.939,-93.939,-97.939,-99.939,-99.939,-98.939,-97.939,-99.939,-100.939,-100.939,-69.939,-46.939003,-19.939003,-19.939003,-18.939003,-16.939003,-12.939003,-5.939003,-0.939003,0.06099701,-6.939003,-17.939003,-29.939003,-24.9390
03,-24.939003,-27.939003,-20.939003,-15.939003,-16.939003,-17.939003,-19.939003,-17.939003,-16.939003,-18.939003,-17.939003,-19.939003,-26.939003,-28.939003,-29.939003,-36.939003,-40.939003,-41.939003,-29.939003,-19.939003,-18.939003,-20.939003,-26.939003,-29.939003,-33.939003,-38.939003,-43.939003,-46.939003,-49.939003,-53.939003,-58.939003,-66.939,-75.939,-86.939,-96.939,-103.939,-103.939,-102.939,-100.939,-98.939,-96.939,-94.939,-93.939,-92.939,-90.939,-88.939,-86.939,-85.939,-84.939,-82.939,-80.939,-79.939,-80.939,-79.939,-77.939,-74.939,-70.939,-67.939,-65.939,-62.939003,-60.939003,-58.939003,-56.939003,-49.939003,-33.939003,-7.939003,-8.939003,-12.939003,-13.939003,-20.939003,-29.939003,-40.939003,-44.939003,-38.939003,-37.939003,-38.939003,-36.939003,-38.939003,-41.939003,-36.939003,-38.939003,-43.939003,-40.939003,-39.939003,-43.939003,-46.939003,-49.939003,-51.939003,-54.939003,-57.939003,-60.939003,-63.939003,-64.939,-67.939,-70.939,-76.939,-80.939,-83.939,-83.939,-84.939,-88.939,-90.939,-91.939,-95.939,-97.939,-100.939,-100.939,-101.939,-101.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-100.939,-100.939,-98.939,-96.939,-91.939,-89.939,-91.939,-42.939003,7.060997,49.060997,55.060997,46.060997,49.060997,8.060997,-74.939,-82.939,-81.939,-80.939,-78.939,-78.939,-78.939,-77.939,-77.939,-76.939,-74.939,-69.939,-66.939,-64.939,-63.939003,-62.939003,-62.939003,-61.939003,-60.939003,-58.939003,-59.939003,-59.939003,-51.939003,-48.939003,-48.939003,-50.939003,-50.939003,-44.939003,-49.939003,-56.939003,-53.939003,-53.939003,-58.939003,-55.939003,-54.939003,-55.939003,-60.939003,-66.939,-64.939,-64.939,-65.939,-64.939,-64.939,-68.939,-72.939,-76.939,-82.939,-84.939,-83.939,-81.939,-81.939,-87.939,-88.939,-87.939,-96.939,-82.939,-44.939003,-31.939003,-19.939003,-4.939003,21.060997,49.060997,62.060997,65.061,58.060997,58.060997,60.060997,62.060997,60.060997,55.060997,55.060997,60.0609
97,69.061,21.060997,-1.939003,47.060997,46.060997,28.060997,50.060997,59.060997,55.060997,46.060997,39.060997,38.060997,40.060997,43.060997,40.060997,38.060997,36.060997,35.060997,33.060997,30.060997,29.060997,28.060997,27.060997,23.060997,13.060997,2.060997,-1.939003,18.060997,17.060997,8.060997,12.060997,6.060997,-8.939003,-8.939003,0.06099701,23.060997,27.060997,22.060997,12.060997,-11.939003,-49.939003,-50.939003,-45.939003,-40.939003,-37.939003,-36.939003,-40.939003,-39.939003,-32.939003,-33.939003,-33.939003,-28.939003,-26.939003,-28.939003,-31.939003,-32.939003,-30.939003,-31.939003,-32.939003,-32.939003,-33.939003,-34.939003,-38.939003,-40.939003,-40.939003,-42.939003,-45.939003,-48.939003,-49.939003,-51.939003,-54.939003,-57.939003,-58.939003,-60.939003,-64.939,-70.939,-76.939,-77.939,-31.939003,-1.939003,13.060997,35.060997,44.060997,19.060997,2.060997,-15.939003,-74.939,-101.939,-97.939,-96.939,-97.939,-102.939,-72.939,-24.939003,50.060997,80.061,64.061,44.060997,4.060997,-78.939,-102.939,-100.939,-99.939,-97.939,-94.939,-95.939,-95.939,-93.939,-88.939,-82.939,-79.939,-77.939,-73.939,-74.939,-74.939,-70.939,-70.939,-72.939,-69.939,-67.939,-65.939,-64.939,-65.939,-65.939,-64.939,-62.939003,-59.939003,-54.939003,-47.939003,-53.939003,-56.939003,-51.939003,-48.939003,-46.939003,-48.939003,-48.939003,-45.939003,-44.939003,-42.939003,-37.939003,-34.939003,-32.939003,-33.939003,-35.939003,-36.939003,-35.939003,-34.939003,-32.939003,-34.939003,-36.939003,-30.939003,-30.939003,-36.939003,-37.939003,-38.939003,-37.939003,-34.939003,-31.939003,-35.939003,-39.939003,-43.939003,-43.939003,-45.939003,-48.939003,-50.939003,-51.939003,-49.939003,-49.939003,-51.939003,-54.939003,-58.939003,-64.939,-65.939,-66.939,-67.939,-71.939,-80.939,-82.939,-84.939,-84.939,-87.939,-90.939,-91.939,-93.939,-94.939,-95.939,-97.939,-100.939,-101.939,-100.939,-102.939,-93.939,-74.939,-56.939003,-38.939003,-19.939003,-14.939003,-16.939003,-28.939003,-30.939003,-21.939003,-18.939003,-16.939
003,-13.939003,-9.939003,-4.939003,-0.939003,3.060997,7.060997,15.060997,24.060997,32.060997,34.060997,37.060997,42.060997,45.060997,47.060997,29.060997,16.060997,19.060997,19.060997,20.060997,18.060997,19.060997,24.060997,24.060997,23.060997,23.060997,24.060997,25.060997,28.060997,31.060997,33.060997,31.060997,30.060997,33.060997,34.060997,34.060997,33.060997,32.060997,33.060997,33.060997,33.060997,35.060997,35.060997,36.060997,37.060997,38.060997,39.060997,37.060997,36.060997,37.060997,34.060997,30.060997,36.060997,38.060997,37.060997,37.060997,37.060997,39.060997,42.060997,46.060997,48.060997,50.060997,49.060997,47.060997,45.060997,43.060997,43.060997,43.060997,37.060997,32.060997,28.060997,32.060997,34.060997,31.060997,31.060997,32.060997,31.060997,30.060997,30.060997,30.060997,30.060997,29.060997,27.060997,25.060997,23.060997,23.060997,26.060997,23.060997,21.060997,21.060997,18.060997,16.060997,13.060997,12.060997,14.060997,14.060997,13.060997,11.060997,10.060997,9.060997,4.060997,2.060997,2.060997,2.060997,1.060997,-2.939003,-1.939003,1.060997,-7.939003,-8.939003,-1.939003,19.060997,33.060997,29.060997,21.060997,11.060997,3.060997,-11.939003,-34.939003,-35.939003,-35.939003,-41.939003,-50.939003,-62.939003,-70.939,-61.939003,-33.939003,-35.939003,-39.939003,-45.939003,-47.939003,-48.939003,-48.939003,-48.939003,-48.939003,-48.939003,-48.939003,-46.939003,-47.939003,-49.939003,-52.939003,-52.939003,-51.939003,-54.939003,-56.939003,-57.939003,-60.939003,-62.939003,-61.939003,-62.939003,-66.939,-68.939,-69.939,-70.939,-73.939,-77.939,-66.939,-56.939003,-45.939003,-47.939003,-48.939003,-46.939003,-44.939003,-44.939003,-55.939003,-52.939003,-33.939003,-3.939003,9.060997,-19.939003,-23.939003,-14.939003,-8.939003,-4.939003,-3.939003,-3.939003,-8.939003,-22.939003,-22.939003,-15.939003,-0.939003,6.060997,7.060997,4.060997,-0.939003,-9.939003,-15.939003,-19.939003,-13.939003,-10.939003,-9.939003,-18.939003,-25.939003,-23.939003,-24.939003,-28.939003,-42.939003,-45.939
003,-40.939003,-41.939003,-41.939003,-38.939003,-36.939003,-35.939003,-40.939003,-43.939003,-45.939003,-46.939003,-43.939003,-36.939003,-31.939003,-28.939003,-29.939003,-28.939003,-28.939003,-31.939003,-34.939003,-39.939003,-39.939003,-31.939003,18.060997,35.060997,17.060997,19.060997,20.060997,16.060997,7.060997,-1.939003,9.060997,17.060997,23.060997,24.060997,23.060997,22.060997,17.060997,7.060997,-22.939003,-18.939003,21.060997,27.060997,26.060997,22.060997,24.060997,28.060997,27.060997,28.060997,32.060997,27.060997,2.060997,-61.939003,-90.939,-100.939,-95.939,-93.939,-93.939,-93.939,-91.939,-85.939,-89.939,-95.939,-96.939,-96.939,-95.939,-99.939,-97.939,-84.939,-74.939,-69.939,-85.939,-92.939,-90.939,-90.939,-89.939,-89.939,-87.939,-85.939,-84.939,-83.939,-80.939,-80.939,-81.939,-81.939,-80.939,-78.939,-78.939,-80.939,-81.939,-65.939,-44.939003,-17.939003,-9.939003,-11.939003,-47.939003,-69.939,-77.939,-74.939,-72.939,-73.939,-74.939,-73.939,-70.939,-70.939,-71.939,-73.939,-74.939,-75.939,-73.939,-70.939,-68.939,-68.939,-71.939,-71.939,-70.939,-69.939,-70.939,-71.939,-67.939,-65.939,-62.939003,-67.939,-70.939,-69.939,-67.939,-63.939003,-64.939,-65.939,-67.939,-69.939,-72.939,-71.939,-71.939,-69.939,-49.939003,-38.939003,-35.939003,-43.939003,-52.939003,-66.939,-68.939,-65.939,-67.939,-69.939,-72.939,-72.939,-72.939,-72.939,-71.939,-70.939,-71.939,-73.939,-73.939,-71.939,-71.939,-75.939,-77.939,-80.939,-77.939,-76.939,-76.939,-80.939,-85.939,-89.939,-84.939,-76.939,-69.939,-71.939,-83.939,-86.939,-88.939,-88.939,-88.939,-89.939,-88.939,-89.939,-92.939,-66.939,-47.939003,39.060997,41.060997,44.060997,48.060997,58.060997,73.061,64.061,45.060997,8.060997,-15.939003,-31.939003,-20.939003,-18.939003,-22.939003,-16.939003,-13.939003,-13.939003,-13.939003,-14.939003,-15.939003,-15.939003,-14.939003,-22.939003,-29.939003,-38.939003,-39.939003,-36.939003,-32.939003,-34.939003,-39.939003,-41.939003,-40.939003,-33.939003,-35.939003,-39.939003,-38.939003,-38.939003,-40.93900
3,-46.939003,-51.939003,-56.939003,-68.939,-80.939,-93.939,-100.939,-101.939,-102.939,-103.939,-103.939,-99.939,-93.939,-87.939,-81.939,-77.939,-74.939,-70.939,-65.939,-60.939003,-55.939003,-50.939003,-45.939003,-40.939003,-37.939003,-34.939003,-35.939003,-32.939003,-28.939003,-27.939003,-25.939003,-23.939003,-19.939003,-17.939003,-19.939003,-21.939003,-21.939003,-20.939003,-20.939003,-22.939003,-22.939003,-19.939003,-12.939003,-10.939003,-13.939003,-39.939003,-55.939003,-60.939003,-62.939003,-65.939,-67.939,-73.939,-81.939,-81.939,-83.939,-85.939,-88.939,-92.939,-94.939,-96.939,-98.939,-96.939,-95.939,-95.939,-97.939,-98.939,-99.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.939,-100.939,-97.939,-99.939,-101.939,-100.939,-100.939,-101.939,-102.939,-102.939,-103.939,-101.939,-101.939,-102.939,-103.939,-103.939,-103.939,-102.939,-100.939,-98.939,-95.939,-93.939,-91.939,-89.939,-82.939,-80.939,-82.939,-50.939003,-17.939003,11.060997,8.060997,-6.939003,-8.939003,-23.939003,-49.939003,-45.939003,-40.939003,-41.939003,-39.939003,-36.939003,-42.939003,-42.939003,-40.939003,-41.939003,-39.939003,-34.939003,-32.939003,-31.939003,-35.939003,-39.939003,-45.939003,-49.939003,-53.939003,-57.939003,-62.939003,-66.939,-66.939,-67.939,-70.939,-72.939,-75.939,-75.939,-82.939,-90.939,-88.939,-89.939,-94.939,-95.939,-96.939,-97.939,-97.939,-97.939,-97.939,-97.939,-98.939,-99.939,-99.939,-100.939,-99.939,-99.939,-100.939,-99.939,-97.939,-92.939,-90.939,-92.939,-95.939,-98.939,-101.939,-82.939,-41.939003,-34.939003,-24.939003,0.06099701,20.060997,40.060997,59.060997,64.061,53.060997,52.060997,53.060997,55.060997,55.060997,56.060997,54.060997,56.060997,64.061,22.060997,-2.939003,25.060997,27.060997,21.060997,39.060997,46.060997,43.060997,29.060997,18.060997,18.060997,16.060997,13.060997,5.060997,-0.939003,-5.939003,-9.939003,-14.939003,-21.939003,-23.939003,-25.939003,-31.939003,-35.939003,-37.939003,-36.939003,-37.939003,-40.939003,-41.93900
3,-42.939003,-52.939003,-54.939003,-47.939003,-14.939003,15.060997,28.060997,29.060997,23.060997,17.060997,-5.939003,-45.939003,-42.939003,-34.939003,-29.939003,-32.939003,-39.939003,-47.939003,-52.939003,-53.939003,-58.939003,-60.939003,-59.939003,-61.939003,-67.939,-72.939,-75.939,-75.939,-78.939,-81.939,-84.939,-87.939,-89.939,-92.939,-93.939,-93.939,-93.939,-94.939,-95.939,-96.939,-96.939,-97.939,-96.939,-95.939,-97.939,-99.939,-100.939,-101.939,-95.939,-34.939003,2.060997,15.060997,36.060997,45.060997,20.060997,6.060997,-8.939003,-71.939,-103.939,-102.939,-102.939,-102.939,-103.939,-75.939,-31.939003,41.060997,69.061,51.060997,33.060997,-2.939003,-80.939,-99.939,-95.939,-90.939,-84.939,-76.939,-78.939,-78.939,-73.939,-66.939,-59.939003,-54.939003,-50.939003,-44.939003,-45.939003,-45.939003,-37.939003,-35.939003,-35.939003,-32.939003,-27.939003,-22.939003,-21.939003,-22.939003,-24.939003,-22.939003,-18.939003,-19.939003,-24.939003,-33.939003,-28.939003,-22.939003,-21.939003,-23.939003,-28.939003,-20.939003,-18.939003,-22.939003,-28.939003,-31.939003,-28.939003,-29.939003,-35.939003,-42.939003,-46.939003,-47.939003,-57.939003,-54.939003,-20.939003,2.060997,21.060997,34.060997,13.060997,-39.939003,-68.939,-86.939,-84.939,-85.939,-87.939,-90.939,-92.939,-93.939,-96.939,-98.939,-96.939,-96.939,-97.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-92.939,-71.939,-50.939003,-27.939003,-6.939003,-13.939003,-31.939003,-31.939003,-27.939003,-20.939003,-14.939003,-11.939003,-13.939003,-11.939003,-7.939003,-1.939003,4.060997,11.060997,19.060997,27.060997,33.060997,35.060997,37.060997,42.060997,42.060997,38.060997,23.060997,13.060997,18.060997,20.060997,21.060997,20.060997,20.060997,23.060997,22.060997,22.060997,23.060997,24.060997,24.060997,27.060997,31.060997,35.060997,32.060997,31.060997,34.060997,33
.060997,31.060997,30.060997,31.060997,36.060997,33.060997,31.060997,32.060997,33.060997,33.060997,35.060997,37.060997,38.060997,39.060997,39.060997,38.060997,34.060997,29.060997,36.060997,38.060997,37.060997,38.060997,38.060997,37.060997,40.060997,45.060997,50.060997,51.060997,48.060997,47.060997,46.060997,41.060997,41.060997,42.060997,35.060997,30.060997,26.060997,31.060997,34.060997,32.060997,32.060997,32.060997,31.060997,30.060997,29.060997,27.060997,27.060997,30.060997,29.060997,27.060997,23.060997,23.060997,28.060997,24.060997,21.060997,22.060997,19.060997,16.060997,15.060997,15.060997,14.060997,15.060997,14.060997,12.060997,11.060997,11.060997,6.060997,4.060997,3.060997,1.060997,-0.939003,-2.939003,-2.939003,-1.939003,-2.939003,-6.939003,-13.939003,8.060997,27.060997,34.060997,24.060997,8.060997,12.060997,-1.939003,-32.939003,-36.939003,-37.939003,-39.939003,-41.939003,-45.939003,-65.939,-63.939003,-38.939003,-28.939003,-25.939003,-41.939003,-52.939003,-59.939003,-66.939,-71.939,-73.939,-73.939,-74.939,-78.939,-82.939,-86.939,-89.939,-91.939,-94.939,-97.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-98.939,-72.939,-54.939003,-42.939003,-43.939003,-43.939003,-41.939003,-39.939003,-41.939003,-54.939003,-51.939003,-32.939003,-4.939003,5.060997,-26.939003,-32.939003,-27.939003,-25.939003,-27.939003,-30.939003,-34.939003,-38.939003,-43.939003,-44.939003,-42.939003,-44.939003,-44.939003,-43.939003,-45.939003,-48.939003,-54.939003,-58.939003,-60.939003,-57.939003,-57.939003,-62.939003,-61.939003,-61.939003,-64.939,-61.939003,-55.939003,-43.939003,-34.939003,-29.939003,-33.939003,-34.939003,-29.939003,-29.939003,-30.939003,-37.939003,-43.939003,-49.939003,-47.939003,-42.939003,-37.939003,-32.939003,-29.939003,-28.939003,-27.939003,-26.939003,-30.939003,-34.939003,-38.939003,-39.939003,-38.939003,-27.939003,-25.939003,-31.939003,-25.939003,-20.939003,-24.939003,-27.939003,-31.939003,-27.939003,-25.939003,-22.939003,-18
.939003,-15.939003,-15.939003,-13.939003,-12.939003,-33.939003,-29.939003,-1.939003,0.06099701,-0.939003,-2.939003,-0.939003,2.060997,6.060997,11.060997,16.060997,12.060997,-8.939003,-62.939003,-87.939,-97.939,-92.939,-91.939,-94.939,-91.939,-89.939,-90.939,-94.939,-98.939,-99.939,-99.939,-100.939,-102.939,-99.939,-88.939,-77.939,-71.939,-91.939,-102.939,-102.939,-101.939,-100.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-68.939,-24.939003,34.060997,43.060997,25.060997,-48.939003,-90.939,-100.939,-99.939,-98.939,-98.939,-98.939,-98.939,-98.939,-99.939,-100.939,-100.939,-100.939,-100.939,-98.939,-96.939,-93.939,-91.939,-92.939,-94.939,-95.939,-94.939,-96.939,-99.939,-94.939,-90.939,-89.939,-91.939,-91.939,-91.939,-88.939,-84.939,-83.939,-83.939,-82.939,-82.939,-83.939,-86.939,-82.939,-72.939,-40.939003,-22.939003,-17.939003,-34.939003,-52.939003,-71.939,-72.939,-68.939,-69.939,-70.939,-71.939,-71.939,-71.939,-70.939,-70.939,-69.939,-67.939,-66.939,-64.939,-60.939003,-58.939003,-62.939003,-62.939003,-62.939003,-62.939003,-62.939003,-61.939003,-64.939,-66.939,-65.939,-63.939003,-61.939003,-57.939003,-59.939003,-66.939,-64.939,-62.939003,-61.939003,-63.939003,-66.939,-66.939,-67.939,-70.939,-61.939003,-55.939003,50.060997,52.060997,54.060997,48.060997,40.060997,28.060997,9.060997,-8.939003,-16.939003,-21.939003,-24.939003,-22.939003,-19.939003,-14.939003,-12.939003,-11.939003,-11.939003,-11.939003,-12.939003,-17.939003,-21.939003,-25.939003,-30.939003,-34.939003,-39.939003,-39.939003,-38.939003,-38.939003,-39.939003,-42.939003,-44.939003,-45.939003,-48.939003,-51.939003,-54.939003,-51.939003,-54.939003,-62.939003,-67.939,-70.939,-73.939,-82.939,-92.939,-93.939,-93.939,-90.939,-82.939,-75.939,-70.939,-64.939,-55.939003,-52.939003,-48.939003,-45.939003,-44.939003,-44.939003,-44.939003,-45.939003,-44.939003,-45.939003,-45.939003,-46.939003,-47.939003,-49.939003,-51.939003,-51.939003,-52.93
9003,-55.939003,-57.939003,-58.939003,-56.939003,-55.939003,-58.939003,-59.939003,-59.939003,-62.939003,-40.939003,7.060997,26.060997,35.060997,36.060997,34.060997,26.060997,-45.939003,-82.939,-85.939,-86.939,-88.939,-89.939,-92.939,-95.939,-94.939,-95.939,-97.939,-99.939,-100.939,-101.939,-100.939,-100.939,-98.939,-99.939,-101.939,-102.939,-102.939,-102.939,-102.939,-103.939,-101.939,-99.939,-97.939,-96.939,-95.939,-94.939,-94.939,-93.939,-93.939,-91.939,-86.939,-85.939,-84.939,-81.939,-79.939,-78.939,-78.939,-77.939,-78.939,-73.939,-68.939,-64.939,-63.939003,-63.939003,-65.939,-64.939,-61.939003,-57.939003,-54.939003,-55.939003,-58.939003,-62.939003,-56.939003,-54.939003,-56.939003,-41.939003,-23.939003,-3.939003,-1.939003,-5.939003,-1.939003,-16.939003,-50.939003,-60.939003,-63.939003,-59.939003,-60.939003,-63.939003,-70.939,-71.939,-68.939,-72.939,-74.939,-71.939,-71.939,-70.939,-72.939,-74.939,-77.939,-79.939,-81.939,-83.939,-85.939,-87.939,-88.939,-88.939,-89.939,-90.939,-91.939,-90.939,-93.939,-98.939,-96.939,-97.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-102.939,-102.939,-101.939,-99.939,-101.939,-100.939,-97.939,-94.939,-92.939,-89.939,-88.939,-88.939,-92.939,-77.939,-42.939003,-33.939003,-25.939003,-10.939003,1.060997,12.060997,25.060997,25.060997,13.060997,11.060997,10.060997,7.060997,3.060997,0.06099701,-0.939003,0.06099701,2.060997,-10.939003,-18.939003,-10.939003,-12.939003,-16.939003,-12.939003,-11.939003,-13.939003,-15.939003,-15.939003,-12.939003,-13.939003,-15.939003,-15.939003,-14.939003,-13.939003,-12.939003,-12.939003,-14.939003,-12.939003,-9.939003,-12.939003,-12.939003,-9.939003,-15.939003,-15.939003,0.06099701,3.060997,2.060997,1.060997,-0.939003,-5.939003,-1.939003,5.060997,22.060997,27.060997,28.060997,24.060997,-2.939003,-49.939003,-65.939,-72.939,-70.939,-72.939,-75.939,-79.939,-81.939,-82.939,-84.939,-85.939,-85.939,-86.939,-89.939,-91.939,-92.939,-93.939,-94.939,-96.939,-96.939,-97.939
,-99.939,-99.939,-100.939,-99.939,-99.939,-98.939,-99.939,-100.939,-102.939,-99.939,-97.939,-98.939,-95.939,-94.939,-96.939,-97.939,-92.939,-33.939003,2.060997,16.060997,37.060997,46.060997,21.060997,4.060997,-11.939003,-59.939003,-78.939,-68.939,-64.939,-60.939003,-53.939003,-45.939003,-34.939003,-7.939003,0.06099701,-10.939003,-18.939003,-27.939003,-39.939003,-43.939003,-44.939003,-43.939003,-41.939003,-36.939003,-39.939003,-42.939003,-41.939003,-40.939003,-41.939003,-42.939003,-42.939003,-40.939003,-44.939003,-47.939003,-42.939003,-43.939003,-45.939003,-45.939003,-44.939003,-45.939003,-46.939003,-47.939003,-51.939003,-55.939003,-57.939003,-59.939003,-55.939003,-43.939003,-53.939003,-61.939003,-61.939003,-59.939003,-57.939003,-61.939003,-64.939,-67.939,-68.939,-69.939,-69.939,-71.939,-73.939,-77.939,-78.939,-77.939,-83.939,-69.939,-7.939003,34.060997,65.061,68.061,33.060997,-37.939003,-74.939,-98.939,-97.939,-97.939,-99.939,-98.939,-98.939,-98.939,-100.939,-101.939,-98.939,-98.939,-100.939,-100.939,-98.939,-95.939,-93.939,-92.939,-88.939,-86.939,-85.939,-82.939,-82.939,-85.939,-83.939,-80.939,-79.939,-76.939,-72.939,-71.939,-69.939,-66.939,-62.939003,-60.939003,-57.939003,-55.939003,-54.939003,-54.939003,-52.939003,-48.939003,-38.939003,-27.939003,-18.939003,-21.939003,-27.939003,-23.939003,-20.939003,-18.939003,-15.939003,-12.939003,-8.939003,-4.939003,-0.939003,-0.939003,2.060997,8.060997,21.060997,31.060997,34.060997,34.060997,35.060997,44.060997,40.060997,26.060997,20.060997,18.060997,20.060997,19.060997,18.060997,20.060997,22.060997,25.060997,24.060997,24.060997,25.060997,26.060997,28.060997,29.060997,30.060997,32.060997,30.060997,29.060997,32.060997,32.060997,32.060997,31.060997,31.060997,34.060997,35.060997,34.060997,32.060997,32.060997,34.060997,36.060997,38.060997,38.060997,36.060997,36.060997,38.060997,37.060997,35.060997,37.060997,36.060997,34.060997,37.060997,40.060997,40.060997,42.060997,45.060997,48.060997,49.060997,47.060997,45.060997,44.060997,41.0
60997,42.060997,42.060997,36.060997,33.060997,33.060997,35.060997,36.060997,34.060997,32.060997,30.060997,31.060997,30.060997,28.060997,26.060997,26.060997,28.060997,27.060997,27.060997,25.060997,25.060997,28.060997,26.060997,23.060997,21.060997,18.060997,17.060997,16.060997,16.060997,16.060997,15.060997,13.060997,13.060997,12.060997,12.060997,8.060997,7.060997,6.060997,4.060997,2.060997,0.06099701,-1.939003,-3.939003,-2.939003,-4.939003,-9.939003,1.060997,15.060997,31.060997,29.060997,18.060997,18.060997,3.060997,-23.939003,-30.939003,-34.939003,-35.939003,-35.939003,-38.939003,-61.939003,-58.939003,-28.939003,4.060997,14.060997,-34.939003,-64.939,-84.939,-88.939,-90.939,-91.939,-91.939,-92.939,-93.939,-95.939,-96.939,-96.939,-96.939,-97.939,-96.939,-95.939,-93.939,-89.939,-86.939,-86.939,-84.939,-83.939,-79.939,-77.939,-78.939,-77.939,-75.939,-61.939003,-51.939003,-44.939003,-46.939003,-47.939003,-46.939003,-45.939003,-44.939003,-48.939003,-49.939003,-47.939003,-39.939003,-36.939003,-42.939003,-45.939003,-48.939003,-48.939003,-49.939003,-50.939003,-47.939003,-44.939003,-45.939003,-48.939003,-51.939003,-46.939003,-42.939003,-38.939003,-40.939003,-43.939003,-44.939003,-42.939003,-40.939003,-36.939003,-33.939003,-32.939003,-31.939003,-31.939003,-25.939003,-23.939003,-23.939003,-41.939003,-48.939003,-44.939003,-44.939003,-43.939003,-42.939003,-43.939003,-44.939003,-41.939003,-41.939003,-42.939003,-42.939003,-41.939003,-37.939003,-32.939003,-28.939003,-29.939003,-28.939003,-26.939003,-28.939003,-32.939003,-38.939003,-38.939003,-35.939003,-20.939003,-19.939003,-31.939003,-28.939003,-26.939003,-26.939003,-29.939003,-34.939003,-35.939003,-35.939003,-31.939003,-28.939003,-26.939003,-28.939003,-28.939003,-27.939003,-36.939003,-37.939003,-29.939003,-28.939003,-29.939003,-34.939003,-36.939003,-35.939003,-30.939003,-24.939003,-19.939003,-20.939003,-29.939003,-52.939003,-63.939003,-67.939,-62.939003,-61.939003,-65.939,-66.939,-66.939,-68.939,-69.939,-70.939,-70.939,-70.939,-71.
939,-72.939,-72.939,-69.939,-65.939,-61.939003,-67.939,-73.939,-78.939,-77.939,-76.939,-74.939,-75.939,-79.939,-79.939,-79.939,-79.939,-80.939,-80.939,-79.939,-78.939,-78.939,-80.939,-83.939,-86.939,-63.939003,-34.939003,3.060997,13.060997,7.060997,-49.939003,-80.939,-85.939,-87.939,-88.939,-88.939,-88.939,-88.939,-89.939,-90.939,-91.939,-92.939,-93.939,-92.939,-90.939,-89.939,-87.939,-86.939,-87.939,-89.939,-90.939,-93.939,-98.939,-103.939,-100.939,-98.939,-97.939,-98.939,-99.939,-99.939,-97.939,-96.939,-95.939,-94.939,-94.939,-93.939,-94.939,-96.939,-90.939,-76.939,-27.939003,0.06099701,9.060997,-23.939003,-56.939003,-83.939,-89.939,-88.939,-88.939,-89.939,-89.939,-89.939,-89.939,-89.939,-89.939,-88.939,-87.939,-86.939,-85.939,-84.939,-83.939,-85.939,-85.939,-84.939,-85.939,-85.939,-84.939,-85.939,-86.939,-85.939,-80.939,-74.939,-70.939,-73.939,-81.939,-82.939,-81.939,-82.939,-82.939,-82.939,-83.939,-81.939,-77.939,-62.939003,-50.939003,53.060997,50.060997,45.060997,28.060997,7.060997,-18.939003,-37.939003,-46.939003,-33.939003,-24.939003,-19.939003,-22.939003,-19.939003,-11.939003,-11.939003,-11.939003,-11.939003,-12.939003,-14.939003,-22.939003,-29.939003,-35.939003,-38.939003,-39.939003,-40.939003,-40.939003,-40.939003,-43.939003,-45.939003,-46.939003,-46.939003,-49.939003,-59.939003,-65.939,-68.939,-65.939,-71.939,-84.939,-85.939,-84.939,-86.939,-92.939,-98.939,-91.939,-86.939,-82.939,-70.939,-59.939003,-52.939003,-46.939003,-38.939003,-38.939003,-37.939003,-36.939003,-37.939003,-39.939003,-43.939003,-47.939003,-49.939003,-54.939003,-58.939003,-62.939003,-66.939,-70.939,-73.939,-75.939,-78.939,-84.939,-87.939,-90.939,-90.939,-90.939,-93.939,-93.939,-92.939,-97.939,-57.939003,27.060997,62.060997,79.061,73.061,69.061,57.060997,-47.939003,-101.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-100.939,-100.939,-101.939,-100.939,-99.939,-98.939,-95.939,-92.939,-92.939,-92.939,-94.939,-93.939,-92.939,-92.939,-92.939,-93.939,-90.939,-86.939,-83.939,-81.939,-
80.939,-78.939,-78.939,-78.939,-77.939,-74.939,-69.939,-66.939,-65.939,-63.939003,-61.939003,-58.939003,-57.939003,-56.939003,-59.939003,-55.939003,-49.939003,-42.939003,-40.939003,-40.939003,-43.939003,-44.939003,-41.939003,-38.939003,-36.939003,-40.939003,-46.939003,-51.939003,-48.939003,-47.939003,-49.939003,-40.939003,-26.939003,3.060997,11.060997,10.060997,22.060997,1.060997,-54.939003,-77.939,-87.939,-81.939,-84.939,-91.939,-96.939,-97.939,-95.939,-99.939,-102.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-99.939,-100.939,-101.939,-98.939,-98.939,-100.939,-99.939,-97.939,-96.939,-95.939,-95.939,-94.939,-93.939,-92.939,-91.939,-90.939,-89.939,-86.939,-84.939,-86.939,-84.939,-81.939,-78.939,-76.939,-72.939,-68.939,-65.939,-69.939,-60.939003,-37.939003,-30.939003,-25.939003,-19.939003,-14.939003,-11.939003,-5.939003,-8.939003,-19.939003,-19.939003,-19.939003,-24.939003,-29.939003,-34.939003,-32.939003,-30.939003,-27.939003,-27.939003,-27.939003,-24.939003,-27.939003,-32.939003,-34.939003,-35.939003,-34.939003,-28.939003,-22.939003,-18.939003,-17.939003,-16.939003,-11.939003,-7.939003,-3.939003,1.060997,4.060997,6.060997,11.060997,17.060997,17.060997,19.060997,23.060997,9.060997,8.060997,41.060997,47.060997,44.060997,51.060997,47.060997,33.060997,8.060997,-3.939003,18.060997,26.060997,30.060997,27.060997,-0.939003,-52.939003,-82.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-98.939,-97.939,-97.939,-96.939,-95.939,-92.939,-91.939,-91.939,-89.939,-88.939,-87.939,-86.939,-84.939,-84.939,-85.939,-85.939,-80.939,-79.939,-81.939,-76.939,-72.939,-74.939,-77.939,-75.939,-29.939003,0.06099701,16.060997,37.060997,45.060997,22.060997,3.060997,-14.939003,-51.939003,-60.939003,-40.939003,-34.939003,-29.939003,-23.939003,-28.939003,-35.939003,-35.939003,-36.939003,-40.939003,-42.939003,-38.939003,-19.939003,-16.9
39003,-21.939003,-24.939003,-26.939003,-24.939003,-28.939003,-31.939003,-33.939003,-37.939003,-42.939003,-47.939003,-50.939003,-52.939003,-57.939003,-61.939003,-59.939003,-61.939003,-65.939,-66.939,-69.939,-73.939,-75.939,-76.939,-80.939,-87.939,-93.939,-96.939,-82.939,-53.939003,-75.939,-96.939,-96.939,-88.939,-81.939,-96.939,-103.939,-102.939,-101.939,-100.939,-101.939,-101.939,-99.939,-99.939,-98.939,-95.939,-97.939,-78.939,-4.939003,43.060997,77.061,70.061,32.060997,-38.939003,-71.939,-92.939,-91.939,-90.939,-90.939,-87.939,-84.939,-83.939,-84.939,-84.939,-81.939,-81.939,-83.939,-82.939,-79.939,-74.939,-71.939,-69.939,-63.939003,-61.939003,-60.939003,-55.939003,-55.939003,-60.939003,-57.939003,-55.939003,-54.939003,-49.939003,-44.939003,-44.939003,-41.939003,-38.939003,-33.939003,-30.939003,-27.939003,-25.939003,-25.939003,-26.939003,-30.939003,-35.939003,-29.939003,-24.939003,-30.939003,-29.939003,-24.939003,-16.939003,-13.939003,-14.939003,-13.939003,-11.939003,-5.939003,0.06099701,5.060997,1.060997,2.060997,8.060997,23.060997,35.060997,35.060997,35.060997,37.060997,44.060997,37.060997,17.060997,18.060997,21.060997,21.060997,19.060997,15.060997,20.060997,24.060997,26.060997,25.060997,25.060997,26.060997,28.060997,30.060997,29.060997,29.060997,30.060997,28.060997,29.060997,31.060997,32.060997,33.060997,32.060997,31.060997,32.060997,36.060997,37.060997,32.060997,32.060997,35.060997,37.060997,38.060997,39.060997,35.060997,34.060997,38.060997,40.060997,39.060997,39.060997,36.060997,32.060997,36.060997,40.060997,42.060997,43.060997,44.060997,47.060997,48.060997,48.060997,46.060997,44.060997,42.060997,42.060997,42.060997,37.060997,35.060997,37.060997,37.060997,36.060997,35.060997,32.060997,28.060997,31.060997,31.060997,28.060997,27.060997,27.060997,26.060997,26.060997,27.060997,27.060997,27.060997,27.060997,27.060997,25.060997,21.060997,19.060997,20.060997,18.060997,16.060997,16.060997,15.060997,14.060997,14.060997,14.060997,13.060997,10.060997,9.060997,9.060997,6.0
60997,4.060997,3.060997,0.06099701,-5.939003,-2.939003,-2.939003,-5.939003,-5.939003,2.060997,26.060997,32.060997,29.060997,23.060997,8.060997,-13.939003,-25.939003,-32.939003,-32.939003,-33.939003,-36.939003,-57.939003,-55.939003,-30.939003,13.060997,28.060997,-32.939003,-69.939,-92.939,-93.939,-92.939,-92.939,-91.939,-91.939,-91.939,-90.939,-88.939,-87.939,-84.939,-83.939,-82.939,-79.939,-75.939,-70.939,-65.939,-66.939,-65.939,-63.939003,-58.939003,-54.939003,-56.939003,-56.939003,-56.939003,-53.939003,-50.939003,-47.939003,-49.939003,-50.939003,-50.939003,-49.939003,-47.939003,-46.939003,-48.939003,-54.939003,-55.939003,-54.939003,-47.939003,-48.939003,-52.939003,-53.939003,-53.939003,-52.939003,-44.939003,-38.939003,-42.939003,-44.939003,-44.939003,-32.939003,-25.939003,-20.939003,-22.939003,-24.939003,-23.939003,-21.939003,-17.939003,-13.939003,-8.939003,-1.939003,-4.939003,-5.939003,5.060997,6.060997,-0.939003,-40.939003,-60.939003,-57.939003,-53.939003,-50.939003,-52.939003,-53.939003,-53.939003,-44.939003,-39.939003,-38.939003,-39.939003,-40.939003,-35.939003,-31.939003,-28.939003,-29.939003,-29.939003,-27.939003,-27.939003,-29.939003,-37.939003,-35.939003,-27.939003,0.06099701,4.060997,-17.939003,-16.939003,-14.939003,-14.939003,-19.939003,-26.939003,-26.939003,-24.939003,-20.939003,-18.939003,-18.939003,-20.939003,-24.939003,-30.939003,-34.939003,-35.939003,-31.939003,-31.939003,-33.939003,-39.939003,-42.939003,-43.939003,-40.939003,-35.939003,-29.939003,-31.939003,-38.939003,-50.939003,-54.939003,-53.939003,-48.939003,-48.939003,-52.939003,-53.939003,-54.939003,-55.939003,-55.939003,-54.939003,-55.939003,-56.939003,-55.939003,-57.939003,-58.939003,-57.939003,-57.939003,-56.939003,-55.939003,-57.939003,-62.939003,-62.939003,-61.939003,-57.939003,-59.939003,-64.939,-63.939003,-62.939003,-62.939003,-63.939003,-63.939003,-63.939003,-61.939003,-59.939003,-63.939003,-68.939,-72.939,-59.939003,-43.939003,-27.939003,-17.939003,-13.939003,-48.939003,-67.939,-69.93
9,-73.939,-75.939,-74.939,-74.939,-73.939,-74.939,-75.939,-77.939,-78.939,-80.939,-77.939,-76.939,-75.939,-75.939,-75.939,-77.939,-78.939,-80.939,-84.939,-89.939,-93.939,-93.939,-92.939,-92.939,-94.939,-95.939,-95.939,-94.939,-94.939,-93.939,-93.939,-93.939,-94.939,-94.939,-96.939,-90.939,-76.939,-23.939003,6.060997,16.060997,-21.939003,-59.939003,-87.939,-97.939,-97.939,-97.939,-97.939,-98.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.939,-101.939,-100.939,-99.939,-101.939,-101.939,-102.939,-102.939,-94.939,-84.939,-80.939,-84.939,-96.939,-99.939,-100.939,-101.939,-99.939,-97.939,-99.939,-95.939,-87.939,-60.939003,-40.939003,52.060997,33.060997,6.060997,-21.939003,-36.939003,-40.939003,-32.939003,-25.939003,-20.939003,-19.939003,-21.939003,-18.939003,-18.939003,-20.939003,-17.939003,-16.939003,-19.939003,-24.939003,-29.939003,-33.939003,-37.939003,-40.939003,-39.939003,-40.939003,-43.939003,-44.939003,-44.939003,-47.939003,-49.939003,-51.939003,-51.939003,-54.939003,-64.939,-71.939,-77.939,-84.939,-91.939,-98.939,-91.939,-85.939,-84.939,-88.939,-93.939,-90.939,-89.939,-89.939,-82.939,-78.939,-77.939,-77.939,-77.939,-80.939,-82.939,-83.939,-85.939,-86.939,-88.939,-89.939,-88.939,-91.939,-93.939,-94.939,-95.939,-95.939,-95.939,-94.939,-93.939,-95.939,-96.939,-95.939,-95.939,-95.939,-96.939,-96.939,-95.939,-96.939,-61.939003,10.060997,55.060997,81.061,68.061,69.061,66.061,-39.939003,-95.939,-102.939,-103.939,-103.939,-102.939,-100.939,-97.939,-96.939,-95.939,-94.939,-88.939,-84.939,-83.939,-80.939,-77.939,-74.939,-71.939,-66.939,-63.939003,-61.939003,-58.939003,-60.939003,-62.939003,-55.939003,-50.939003,-47.939003,-45.939003,-43.939003,-41.939003,-43.939003,-45.939003,-44.939003,-42.939003,-38.939003,-41.939003,-46.939003,-52.939003,-52.939003,-53.939003,-50.939003,-51.939003,-59.939003,-66.939,-69.939,-67.939,-66.939,-67.939,-72.939,-76.939,-80.939,-81.939,-82.939,-87.939,-88.939,-86.939,-86.939,-88.939,-90.939,-65.
939,-23.939003,55.060997,66.061,52.060997,69.061,32.060997,-57.939003,-85.939,-99.939,-93.939,-95.939,-100.939,-99.939,-99.939,-101.939,-100.939,-99.939,-100.939,-99.939,-98.939,-101.939,-102.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-99.939,-98.939,-101.939,-103.939,-101.939,-101.939,-100.939,-95.939,-92.939,-91.939,-87.939,-82.939,-81.939,-77.939,-71.939,-69.939,-65.939,-59.939003,-54.939003,-50.939003,-46.939003,-41.939003,-37.939003,-35.939003,-32.939003,-26.939003,-23.939003,-21.939003,-22.939003,-21.939003,-20.939003,-17.939003,-18.939003,-21.939003,-22.939003,-23.939003,-22.939003,-21.939003,-19.939003,-20.939003,-23.939003,-24.939003,-19.939003,-13.939003,-9.939003,-6.939003,-4.939003,-0.939003,11.060997,29.060997,-2.939003,-17.939003,14.060997,18.060997,11.060997,25.060997,34.060997,36.060997,36.060997,36.060997,40.060997,42.060997,44.060997,44.060997,44.060997,43.060997,44.060997,45.060997,47.060997,48.060997,48.060997,49.060997,51.060997,52.060997,25.060997,17.060997,60.060997,65.061,56.060997,59.060997,54.060997,42.060997,9.060997,-6.939003,25.060997,30.060997,25.060997,24.060997,0.06099701,-47.939003,-81.939,-103.939,-102.939,-102.939,-103.939,-102.939,-101.939,-100.939,-97.939,-95.939,-94.939,-92.939,-90.939,-81.939,-77.939,-76.939,-72.939,-68.939,-63.939003,-59.939003,-55.939003,-51.939003,-47.939003,-42.939003,-41.939003,-39.939003,-38.939003,-34.939003,-30.939003,-25.939003,-24.939003,-26.939003,-24.939003,-23.939003,-23.939003,-26.939003,-29.939003,-19.939003,-6.939003,11.060997,34.060997,44.060997,21.060997,4.060997,-13.939003,-55.939003,-64.939,-39.939003,-36.939003,-41.939003,-51.939003,-46.939003,-33.939003,-1.939003,14.060997,14.060997,5.060997,-13.939003,-58.939003,-71.939,-73.939,-78.939,-81.939,-82.939,-84.939,-86.939,-86.939,-87.939,-89.939,-90.939,-91.939,-91.939,-92.939,-93.939,-92.939,-93.939,-94.939,-95.939,-95.939,-96.939,-96.939,-97.939,-98.939,-99.939,-101.939,-101.939,-87.939,-57.939003,-80.939,-101.939,-10
1.939,-90.939,-77.939,-94.939,-102.939,-100.939,-100.939,-99.939,-97.939,-92.939,-87.939,-86.939,-84.939,-81.939,-86.939,-74.939,-24.939003,4.060997,19.060997,5.060997,-16.939003,-43.939003,-51.939003,-53.939003,-51.939003,-49.939003,-47.939003,-43.939003,-39.939003,-35.939003,-33.939003,-32.939003,-29.939003,-30.939003,-31.939003,-30.939003,-28.939003,-25.939003,-23.939003,-20.939003,-21.939003,-23.939003,-25.939003,-23.939003,-24.939003,-25.939003,-26.939003,-26.939003,-29.939003,-30.939003,-31.939003,-36.939003,-40.939003,-40.939003,-42.939003,-43.939003,-46.939003,-49.939003,-53.939003,-60.939003,-59.939003,-52.939003,-27.939003,-14.939003,-35.939003,-35.939003,-25.939003,-15.939003,-10.939003,-9.939003,-8.939003,-7.939003,-5.939003,-1.939003,4.060997,3.060997,6.060997,13.060997,27.060997,36.060997,33.060997,40.060997,48.060997,42.060997,31.060997,15.060997,16.060997,18.060997,18.060997,17.060997,16.060997,20.060997,22.060997,23.060997,24.060997,24.060997,24.060997,26.060997,29.060997,27.060997,28.060997,30.060997,30.060997,30.060997,33.060997,34.060997,35.060997,34.060997,33.060997,32.060997,35.060997,36.060997,31.060997,33.060997,37.060997,36.060997,39.060997,44.060997,39.060997,36.060997,40.060997,41.060997,40.060997,40.060997,39.060997,37.060997,37.060997,37.060997,37.060997,38.060997,42.060997,46.060997,50.060997,53.060997,51.060997,48.060997,46.060997,44.060997,42.060997,37.060997,34.060997,33.060997,34.060997,34.060997,32.060997,30.060997,28.060997,32.060997,32.060997,30.060997,29.060997,28.060997,28.060997,28.060997,27.060997,26.060997,25.060997,23.060997,25.060997,26.060997,24.060997,24.060997,25.060997,21.060997,18.060997,14.060997,15.060997,16.060997,17.060997,15.060997,13.060997,10.060997,8.060997,7.060997,6.060997,6.060997,6.060997,1.060997,-4.939003,-4.939003,-4.939003,-4.939003,-11.939003,-8.939003,18.060997,32.060997,38.060997,28.060997,13.060997,-5.939003,-22.939003,-33.939003,-30.939003,-35.939003,-44.939003,-50.939003,-57.939003,-63.939003,-39
.939003,-24.939003,-42.939003,-51.939003,-57.939003,-58.939003,-57.939003,-55.939003,-53.939003,-51.939003,-49.939003,-49.939003,-48.939003,-45.939003,-43.939003,-42.939003,-43.939003,-44.939003,-43.939003,-41.939003,-39.939003,-44.939003,-48.939003,-50.939003,-47.939003,-45.939003,-47.939003,-52.939003,-57.939003,-54.939003,-52.939003,-51.939003,-50.939003,-51.939003,-49.939003,-48.939003,-48.939003,-56.939003,-53.939003,-40.939003,-14.939003,-3.939003,-29.939003,-27.939003,-14.939003,-15.939003,-14.939003,-11.939003,-5.939003,-7.939003,-30.939003,-23.939003,-5.939003,8.060997,14.060997,11.060997,12.060997,11.060997,5.060997,0.06099701,-4.939003,-2.939003,1.060997,8.060997,-0.939003,-6.939003,-0.939003,-4.939003,-14.939003,-45.939003,-59.939003,-56.939003,-52.939003,-49.939003,-49.939003,-48.939003,-48.939003,-42.939003,-39.939003,-39.939003,-41.939003,-38.939003,-30.939003,-27.939003,-28.939003,-27.939003,-28.939003,-29.939003,-27.939003,-28.939003,-36.939003,-28.939003,-12.939003,41.060997,51.060997,16.060997,24.060997,30.060997,20.060997,10.060997,1.060997,17.060997,27.060997,31.060997,26.060997,22.060997,29.060997,16.060997,-1.939003,-21.939003,-11.939003,27.060997,23.060997,18.060997,21.060997,20.060997,17.060997,14.060997,17.060997,26.060997,10.060997,-17.939003,-70.939,-85.939,-85.939,-83.939,-82.939,-82.939,-78.939,-75.939,-74.939,-74.939,-75.939,-80.939,-83.939,-80.939,-83.939,-81.939,-65.939,-62.939003,-65.939,-76.939,-79.939,-74.939,-75.939,-77.939,-76.939,-76.939,-76.939,-73.939,-69.939,-63.939003,-64.939,-67.939,-68.939,-66.939,-64.939,-65.939,-66.939,-68.939,-56.939003,-44.939003,-35.939003,-28.939003,-25.939003,-43.939003,-56.939003,-63.939003,-65.939,-63.939003,-56.939003,-55.939003,-57.939003,-59.939003,-60.939003,-62.939003,-62.939003,-61.939003,-58.939003,-56.939003,-54.939003,-55.939003,-57.939003,-61.939003,-63.939003,-63.939003,-62.939003,-61.939003,-62.939003,-59.939003,-60.939003,-64.939,-67.939,-69.939,-67.939,-66.939,-64.939,-65.939,-65.93
9,-66.939,-69.939,-72.939,-74.939,-70.939,-64.939,-42.939003,-28.939003,-23.939003,-39.939003,-56.939003,-69.939,-75.939,-78.939,-79.939,-79.939,-79.939,-82.939,-84.939,-85.939,-86.939,-86.939,-85.939,-84.939,-85.939,-84.939,-84.939,-86.939,-91.939,-95.939,-88.939,-88.939,-92.939,-96.939,-100.939,-98.939,-91.939,-81.939,-77.939,-82.939,-98.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-100.939,-95.939,-55.939003,-24.939003,12.060997,-2.939003,-22.939003,-34.939003,-38.939003,-34.939003,-27.939003,-21.939003,-19.939003,-17.939003,-16.939003,-16.939003,-17.939003,-20.939003,-23.939003,-25.939003,-25.939003,-30.939003,-38.939003,-39.939003,-41.939003,-42.939003,-40.939003,-41.939003,-44.939003,-48.939003,-50.939003,-51.939003,-53.939003,-57.939003,-61.939003,-67.939,-76.939,-83.939,-90.939,-94.939,-99.939,-102.939,-88.939,-74.939,-69.939,-79.939,-95.939,-94.939,-93.939,-93.939,-92.939,-91.939,-91.939,-92.939,-92.939,-94.939,-96.939,-96.939,-99.939,-100.939,-101.939,-100.939,-99.939,-102.939,-103.939,-101.939,-102.939,-102.939,-101.939,-100.939,-98.939,-95.939,-93.939,-93.939,-91.939,-89.939,-87.939,-86.939,-85.939,-87.939,-62.939003,-8.939003,26.060997,46.060997,31.060997,30.060997,26.060997,-38.939003,-70.939,-70.939,-72.939,-74.939,-74.939,-70.939,-65.939,-67.939,-67.939,-65.939,-62.939003,-60.939003,-60.939003,-61.939003,-62.939003,-62.939003,-60.939003,-55.939003,-57.939003,-58.939003,-58.939003,-61.939003,-65.939,-60.939003,-58.939003,-59.939003,-58.939003,-56.939003,-56.939003,-60.939003,-65.939,-64.939,-63.939003,-61.939003,-64.939,-67.939,-72.939,-73.939,-74.939,-72.939,-73.939,-78.939,-83.939,-86.939,-85.939,-86.939,-87.939,-90.939,-92.939,-95.939,-97.939,-99.939,-101.939,-100.939,-100.939,-100.939,-101.939,-103.939,-79.939,-32.939003,64.061,78.061,61.060997,86.061,48.060997,-51.939003,-85.939,-102.939,-99.939,-99.939,-100.939,-99.939,-99.939,-102.939,-97.939,-94.939,-92.939,-89.939,-85.939,-84.939,-83.939,-82.939,-79.939,-77.939,-74.939,-72.939,-7
0.939,-69.939,-67.939,-65.939,-64.939,-62.939003,-57.939003,-57.939003,-58.939003,-55.939003,-52.939003,-49.939003,-45.939003,-43.939003,-46.939003,-45.939003,-42.939003,-41.939003,-40.939003,-38.939003,-36.939003,-36.939003,-36.939003,-35.939003,-34.939003,-35.939003,-33.939003,-29.939003,-32.939003,-35.939003,-36.939003,-38.939003,-40.939003,-43.939003,-41.939003,-33.939003,-28.939003,-23.939003,-16.939003,-8.939003,1.060997,12.060997,14.060997,8.060997,14.060997,20.060997,25.060997,27.060997,27.060997,28.060997,40.060997,60.060997,18.060997,-3.939003,31.060997,38.060997,33.060997,48.060997,59.060997,63.060997,57.060997,54.060997,59.060997,60.060997,59.060997,59.060997,58.060997,57.060997,56.060997,55.060997,55.060997,55.060997,54.060997,54.060997,53.060997,53.060997,22.060997,9.060997,47.060997,51.060997,42.060997,40.060997,33.060997,20.060997,2.060997,-3.939003,26.060997,31.060997,25.060997,23.060997,1.060997,-41.939003,-60.939003,-69.939,-61.939003,-59.939003,-60.939003,-60.939003,-59.939003,-58.939003,-57.939003,-55.939003,-54.939003,-52.939003,-50.939003,-48.939003,-47.939003,-49.939003,-46.939003,-44.939003,-42.939003,-41.939003,-41.939003,-41.939003,-40.939003,-39.939003,-42.939003,-43.939003,-43.939003,-42.939003,-40.939003,-39.939003,-40.939003,-43.939003,-43.939003,-44.939003,-45.939003,-49.939003,-51.939003,-29.939003,-8.939003,12.060997,32.060997,40.060997,21.060997,6.060997,-10.939003,-63.939003,-81.939,-66.939,-66.939,-70.939,-78.939,-64.939,-36.939003,26.060997,52.060997,41.060997,36.060997,9.060997,-70.939,-95.939,-95.939,-98.939,-100.939,-101.939,-102.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-100.939,-99.939,-98.939,-95.939,-95.939,-96.939,-95.939,-92.939,-90.939,-88.939,-85.939,-85.939,-84.939,-82.939,-80.939,-68.939,-47.939003,-61.939003,-74.939,-73.939,-65.939,-56.939003,-66.939,-68.939,-62.939003,-60.939003,-58.939003,-57.939003,-52.939003,-47.939003,-48.939003,-47.939003,-44.939003,-48.939003,-46.939003,-29.939003,-22.939003,
-21.939003,-27.939003,-33.939003,-40.939003,-40.939003,-39.939003,-39.939003,-40.939003,-41.939003,-40.939003,-38.939003,-36.939003,-36.939003,-38.939003,-41.939003,-43.939003,-44.939003,-43.939003,-43.939003,-43.939003,-42.939003,-42.939003,-45.939003,-48.939003,-52.939003,-52.939003,-53.939003,-53.939003,-54.939003,-55.939003,-57.939003,-59.939003,-60.939003,-64.939,-67.939,-68.939,-69.939,-71.939,-74.939,-77.939,-80.939,-84.939,-76.939,-54.939003,-30.939003,-17.939003,-35.939003,-34.939003,-25.939003,-17.939003,-12.939003,-9.939003,-6.939003,-4.939003,-1.939003,1.060997,5.060997,4.060997,9.060997,19.060997,28.060997,34.060997,34.060997,41.060997,49.060997,32.060997,22.060997,16.060997,16.060997,16.060997,18.060997,20.060997,20.060997,23.060997,23.060997,23.060997,24.060997,25.060997,26.060997,27.060997,28.060997,26.060997,28.060997,32.060997,34.060997,35.060997,35.060997,35.060997,35.060997,35.060997,35.060997,35.060997,35.060997,34.060997,33.060997,36.060997,39.060997,38.060997,39.060997,44.060997,40.060997,38.060997,41.060997,41.060997,38.060997,39.060997,40.060997,39.060997,38.060997,38.060997,37.060997,39.060997,42.060997,47.060997,52.060997,55.060997,53.060997,50.060997,47.060997,45.060997,43.060997,39.060997,35.060997,32.060997,34.060997,34.060997,33.060997,32.060997,31.060997,33.060997,33.060997,32.060997,31.060997,30.060997,30.060997,29.060997,27.060997,26.060997,26.060997,24.060997,26.060997,27.060997,26.060997,25.060997,25.060997,23.060997,20.060997,16.060997,17.060997,19.060997,18.060997,16.060997,14.060997,10.060997,8.060997,6.060997,5.060997,6.060997,7.060997,3.060997,-2.939003,-2.939003,-3.939003,-3.939003,-8.939003,-8.939003,4.060997,22.060997,38.060997,27.060997,15.060997,6.060997,-15.939003,-30.939003,-29.939003,-35.939003,-43.939003,-44.939003,-51.939003,-66.939,-60.939003,-52.939003,-45.939003,-48.939003,-55.939003,-52.939003,-53.939003,-56.939003,-54.939003,-53.939003,-54.939003,-56.939003,-59.939003,-58.939003,-58.939003,-58.939003,-61.939003
,-63.939003,-64.939,-63.939003,-62.939003,-66.939,-69.939,-71.939,-70.939,-69.939,-70.939,-74.939,-76.939,-64.939,-53.939003,-46.939003,-47.939003,-48.939003,-45.939003,-43.939003,-43.939003,-59.939003,-53.939003,-24.939003,0.06099701,8.060997,-24.939003,-23.939003,-10.939003,-13.939003,-14.939003,-14.939003,-12.939003,-16.939003,-34.939003,-28.939003,-12.939003,-6.939003,-5.939003,-10.939003,-12.939003,-15.939003,-20.939003,-25.939003,-29.939003,-28.939003,-26.939003,-23.939003,-30.939003,-35.939003,-33.939003,-32.939003,-33.939003,-45.939003,-50.939003,-49.939003,-45.939003,-42.939003,-43.939003,-43.939003,-42.939003,-39.939003,-38.939003,-38.939003,-40.939003,-38.939003,-31.939003,-29.939003,-30.939003,-27.939003,-25.939003,-24.939003,-24.939003,-26.939003,-31.939003,-29.939003,-21.939003,12.060997,16.060997,-11.939003,-3.939003,1.060997,-4.939003,-9.939003,-13.939003,-1.939003,5.060997,8.060997,6.060997,6.060997,13.060997,2.060997,-13.939003,-27.939003,-16.939003,17.060997,13.060997,9.060997,13.060997,15.060997,15.060997,12.060997,16.060997,26.060997,7.060997,-24.939003,-76.939,-91.939,-90.939,-88.939,-87.939,-87.939,-84.939,-82.939,-83.939,-84.939,-86.939,-92.939,-94.939,-91.939,-95.939,-91.939,-67.939,-67.939,-76.939,-90.939,-93.939,-88.939,-89.939,-90.939,-91.939,-91.939,-90.939,-88.939,-85.939,-80.939,-81.939,-83.939,-84.939,-83.939,-82.939,-82.939,-82.939,-83.939,-53.939003,-23.939003,-0.939003,-0.939003,-13.939003,-52.939003,-75.939,-79.939,-80.939,-79.939,-74.939,-73.939,-74.939,-75.939,-77.939,-78.939,-77.939,-77.939,-74.939,-73.939,-72.939,-73.939,-74.939,-77.939,-77.939,-76.939,-75.939,-72.939,-70.939,-66.939,-65.939,-67.939,-71.939,-73.939,-70.939,-69.939,-67.939,-66.939,-67.939,-67.939,-67.939,-69.939,-74.939,-69.939,-60.939003,-42.939003,-32.939003,-28.939003,-44.939003,-59.939003,-67.939,-70.939,-70.939,-69.939,-69.939,-70.939,-72.939,-73.939,-74.939,-74.939,-73.939,-70.939,-69.939,-69.939,-70.939,-71.939,-72.939,-75.939,-77.939,-71.939,-70.939,-75
.939,-79.939,-82.939,-79.939,-74.939,-67.939,-65.939,-69.939,-80.939,-81.939,-81.939,-80.939,-82.939,-84.939,-84.939,-82.939,-81.939,-55.939003,-35.939003,-37.939003,-40.939003,-41.939003,-30.939003,-22.939003,-18.939003,-21.939003,-24.939003,-22.939003,-16.939003,-10.939003,-15.939003,-18.939003,-19.939003,-30.939003,-35.939003,-31.939003,-35.939003,-45.939003,-42.939003,-42.939003,-44.939003,-42.939003,-41.939003,-46.939003,-51.939003,-56.939003,-56.939003,-58.939003,-64.939,-74.939,-83.939,-91.939,-97.939,-103.939,-101.939,-101.939,-103.939,-80.939,-59.939003,-48.939003,-68.939,-100.939,-99.939,-98.939,-96.939,-99.939,-101.939,-99.939,-98.939,-97.939,-95.939,-94.939,-93.939,-96.939,-99.939,-97.939,-97.939,-96.939,-98.939,-98.939,-96.939,-97.939,-97.939,-97.939,-96.939,-94.939,-87.939,-84.939,-84.939,-79.939,-76.939,-71.939,-68.939,-68.939,-73.939,-59.939003,-26.939003,-7.939003,0.06099701,-11.939003,-18.939003,-24.939003,-39.939003,-42.939003,-32.939003,-35.939003,-39.939003,-41.939003,-37.939003,-29.939003,-34.939003,-36.939003,-34.939003,-36.939003,-38.939003,-40.939003,-46.939003,-52.939003,-57.939003,-58.939003,-55.939003,-62.939003,-70.939,-74.939,-77.939,-82.939,-81.939,-83.939,-90.939,-90.939,-89.939,-90.939,-95.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-99.939,-101.939,-103.939,-102.939,-101.939,-100.939,-102.939,-102.939,-100.939,-101.939,-102.939,-102.939,-101.939,-101.939,-85.939,-44.939003,51.060997,68.061,54.060997,83.061,51.060997,-41.939003,-78.939,-97.939,-95.939,-92.939,-90.939,-91.939,-92.939,-94.939,-88.939,-83.939,-78.939,-73.939,-66.939,-60.939003,-56.939003,-53.939003,-48.939003,-43.939003,-36.939003,-33.939003,-30.939003,-30.939003,-29.939003,-27.939003,-20.939003,-12.939003,-5.939003,-5.939003,-11.939003,-11.939003,-9.939003,-4.939003,-2.939003,-5.939003,-13.939003,-18.939003,-20.939003,-21.939003,-24.939003,-29.939003,-33.939003,-38.939003,-44.939003,-
49.939003,-52.939003,-57.939003,-58.939003,-57.939003,-67.939,-76.939,-76.939,-78.939,-84.939,-96.939,-87.939,-56.939003,-39.939003,-24.939003,-7.939003,11.060997,32.060997,61.060997,69.061,55.060997,58.060997,61.060997,63.060997,60.060997,56.060997,52.060997,60.060997,76.061,34.060997,8.060997,36.060997,44.060997,43.060997,51.060997,59.060997,66.061,56.060997,50.060997,57.060997,55.060997,49.060997,51.060997,52.060997,52.060997,49.060997,47.060997,44.060997,44.060997,43.060997,41.060997,39.060997,38.060997,9.060997,-5.939003,20.060997,22.060997,14.060997,10.060997,1.060997,-12.939003,-8.939003,2.060997,25.060997,30.060997,27.060997,24.060997,3.060997,-35.939003,-36.939003,-28.939003,-11.939003,-7.939003,-10.939003,-10.939003,-10.939003,-10.939003,-11.939003,-12.939003,-12.939003,-10.939003,-8.939003,-17.939003,-22.939003,-26.939003,-27.939003,-28.939003,-31.939003,-35.939003,-41.939003,-47.939003,-51.939003,-56.939003,-62.939003,-66.939,-70.939,-73.939,-76.939,-79.939,-82.939,-86.939,-88.939,-91.939,-95.939,-100.939,-98.939,-46.939003,-8.939003,15.060997,30.060997,36.060997,21.060997,8.060997,-8.939003,-71.939,-102.939,-102.939,-102.939,-102.939,-102.939,-78.939,-39.939003,46.060997,76.061,49.060997,52.060997,26.060997,-70.939,-99.939,-99.939,-99.939,-98.939,-98.939,-98.939,-98.939,-96.939,-96.939,-95.939,-95.939,-93.939,-90.939,-89.939,-86.939,-81.939,-80.939,-82.939,-79.939,-75.939,-70.939,-64.939,-60.939003,-59.939003,-57.939003,-54.939003,-49.939003,-42.939003,-34.939003,-35.939003,-36.939003,-34.939003,-32.939003,-31.939003,-31.939003,-27.939003,-18.939003,-12.939003,-9.939003,-10.939003,-9.939003,-7.939003,-9.939003,-10.939003,-7.939003,-10.939003,-15.939003,-26.939003,-37.939003,-44.939003,-36.939003,-32.939003,-33.939003,-36.939003,-39.939003,-43.939003,-48.939003,-54.939003,-55.939003,-57.939003,-59.939003,-64.939,-69.939,-77.939,-79.939,-80.939,-80.939,-82.939,-86.939,-87.939,-89.939,-93.939,-96.939,-100.939,-102.939,-103.939,-103.939,-103.939,-103.939,-1
02.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-83.939,-50.939003,-32.939003,-24.939003,-34.939003,-31.939003,-23.939003,-19.939003,-15.939003,-12.939003,-6.939003,-1.939003,3.060997,5.060997,6.060997,4.060997,11.060997,26.060997,28.060997,30.060997,36.060997,42.060997,45.060997,21.060997,13.060997,19.060997,17.060997,15.060997,20.060997,23.060997,26.060997,26.060997,25.060997,24.060997,25.060997,26.060997,28.060997,28.060997,28.060997,27.060997,30.060997,35.060997,38.060997,39.060997,38.060997,37.060997,36.060997,37.060997,38.060997,39.060997,35.060997,32.060997,36.060997,39.060997,42.060997,40.060997,40.060997,42.060997,40.060997,40.060997,42.060997,40.060997,35.060997,37.060997,39.060997,41.060997,40.060997,40.060997,40.060997,41.060997,44.060997,49.060997,53.060997,56.060997,54.060997,51.060997,48.060997,46.060997,45.060997,41.060997,37.060997,33.060997,35.060997,36.060997,35.060997,34.060997,35.060997,34.060997,33.060997,34.060997,32.060997,31.060997,32.060997,31.060997,28.060997,27.060997,27.060997,26.060997,28.060997,29.060997,27.060997,25.060997,23.060997,24.060997,22.060997,19.060997,20.060997,21.060997,19.060997,17.060997,15.060997,10.060997,7.060997,4.060997,4.060997,5.060997,8.060997,5.060997,0.06099701,-0.939003,-1.939003,-3.939003,-1.939003,-2.939003,-11.939003,7.060997,34.060997,23.060997,17.060997,18.060997,-6.939003,-26.939003,-28.939003,-33.939003,-39.939003,-38.939003,-44.939003,-56.939003,-65.939,-67.939,-49.939003,-54.939003,-67.939,-61.939003,-64.939,-74.939,-73.939,-72.939,-77.939,-82.939,-89.939,-92.939,-94.939,-95.939,-98.939,-100.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-97.939,-74.939,-54.939003,-38.939003,-42.939003,-45.939003,-40.939003,-38.939003,-37.939003,-58.939003,-50.939003,-12.939003,0.06099701,-0.939003,-26.939003,-30.939003,-23.939003,-28.939003,-31.939003,-35.939003,-37.939003,-40.939003,-44.939003,-43.939003,-38.9390
03,-42.939003,-46.939003,-50.939003,-57.939003,-62.939003,-63.939003,-63.939003,-63.939003,-62.939003,-62.939003,-63.939003,-65.939,-66.939,-67.939,-60.939003,-49.939003,-43.939003,-41.939003,-43.939003,-40.939003,-37.939003,-38.939003,-38.939003,-39.939003,-37.939003,-36.939003,-35.939003,-38.939003,-39.939003,-34.939003,-32.939003,-32.939003,-27.939003,-22.939003,-17.939003,-21.939003,-25.939003,-28.939003,-32.939003,-36.939003,-36.939003,-43.939003,-57.939003,-53.939003,-49.939003,-49.939003,-46.939003,-41.939003,-42.939003,-41.939003,-39.939003,-34.939003,-30.939003,-27.939003,-32.939003,-39.939003,-40.939003,-33.939003,-19.939003,-22.939003,-24.939003,-21.939003,-17.939003,-13.939003,-15.939003,-9.939003,2.060997,-15.939003,-40.939003,-74.939,-83.939,-81.939,-78.939,-78.939,-81.939,-81.939,-82.939,-86.939,-88.939,-90.939,-94.939,-95.939,-93.939,-97.939,-91.939,-65.939,-69.939,-85.939,-95.939,-100.939,-100.939,-99.939,-98.939,-100.939,-100.939,-100.939,-100.939,-99.939,-97.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-101.939,-51.939003,0.06099701,43.060997,36.060997,3.060997,-66.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-103.939,-101.939,-100.939,-100.939,-97.939,-94.939,-89.939,-86.939,-84.939,-88.939,-90.939,-86.939,-85.939,-84.939,-81.939,-80.939,-80.939,-76.939,-76.939,-84.939,-76.939,-59.939003,-33.939003,-19.939003,-17.939003,-42.939003,-65.939,-74.939,-73.939,-69.939,-67.939,-66.939,-66.939,-67.939,-67.939,-68.939,-66.939,-65.939,-60.939003,-57.939003,-56.939003,-60.939003,-63.939003,-62.939003,-61.939003,-60.939003,-55.939003,-55.939003,-58.939003,-61.939003,-63.939003,-58.939003,-55.939003,-52.939003,-54.939003,-55.939003,-59.939003,-56.939003,-54.939003,-53.939003,-56.939003,-62.939003,-61.939003,-60.939003,-60.939003,-56.939003,-53.939003,-27.939003,-25.939003,-22.939003,-19.939003,-17.939003,-19.939003,-15.939003,-1
2.939003,-10.939003,-12.939003,-17.939003,-25.939003,-29.939003,-31.939003,-37.939003,-39.939003,-36.939003,-39.939003,-44.939003,-43.939003,-45.939003,-49.939003,-51.939003,-52.939003,-51.939003,-52.939003,-53.939003,-62.939003,-73.939,-86.939,-91.939,-95.939,-99.939,-101.939,-103.939,-102.939,-102.939,-103.939,-73.939,-46.939003,-37.939003,-63.939003,-102.939,-100.939,-99.939,-100.939,-98.939,-97.939,-95.939,-94.939,-92.939,-88.939,-83.939,-78.939,-77.939,-76.939,-74.939,-74.939,-72.939,-67.939,-63.939003,-62.939003,-63.939003,-60.939003,-52.939003,-49.939003,-50.939003,-51.939003,-49.939003,-45.939003,-44.939003,-43.939003,-40.939003,-41.939003,-44.939003,-47.939003,-41.939003,-27.939003,-10.939003,-0.939003,-4.939003,0.06099701,4.060997,-30.939003,-53.939003,-65.939,-67.939,-69.939,-72.939,-72.939,-72.939,-71.939,-72.939,-75.939,-78.939,-80.939,-81.939,-83.939,-86.939,-88.939,-86.939,-82.939,-87.939,-91.939,-91.939,-92.939,-95.939,-96.939,-96.939,-98.939,-99.939,-98.939,-99.939,-100.939,-101.939,-100.939,-100.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-103.939,-103.939,-102.939,-102.939,-103.939,-102.939,-98.939,-98.939,-97.939,-89.939,-83.939,-80.939,-70.939,-45.939003,14.060997,25.060997,15.060997,18.060997,2.060997,-30.939003,-43.939003,-48.939003,-40.939003,-36.939003,-34.939003,-31.939003,-30.939003,-31.939003,-31.939003,-31.939003,-31.939003,-30.939003,-27.939003,-25.939003,-27.939003,-30.939003,-31.939003,-30.939003,-29.939003,-28.939003,-29.939003,-32.939003,-36.939003,-42.939003,-45.939003,-46.939003,-46.939003,-50.939003,-57.939003,-62.939003,-65.939,-66.939,-66.939,-67.939,-70.939,-73.939,-75.939,-76.939,-77.939,-78.939,-80.939,-81.939,-83.939,-85.939,-86.939,-88.939,-88.939,-88.939,-91.939,-94.939,-94.939,-93.939,-94.939,-100.939,-92.939,-69.939,-44.939003,-21.939003,-6.939003,12.060997,33.060997,63.060997,72.061,60.060997,57.060997,56.060997,55.060997,53.060997,51.060997,4
5.060997,49.060997,61.060997,24.060997,-1.939003,13.060997,19.060997,20.060997,20.060997,23.060997,30.060997,17.060997,7.060997,6.060997,6.060997,5.060997,2.060997,-2.939003,-6.939003,-7.939003,-7.939003,-11.939003,-14.939003,-16.939003,-20.939003,-21.939003,-20.939003,-18.939003,-16.939003,-17.939003,-21.939003,-26.939003,-21.939003,-21.939003,-27.939003,-15.939003,0.06099701,20.060997,27.060997,27.060997,23.060997,1.060997,-36.939003,-48.939003,-53.939003,-46.939003,-49.939003,-55.939003,-59.939003,-61.939003,-61.939003,-64.939,-68.939,-71.939,-72.939,-71.939,-74.939,-76.939,-77.939,-77.939,-77.939,-79.939,-80.939,-82.939,-84.939,-85.939,-85.939,-88.939,-91.939,-92.939,-93.939,-94.939,-95.939,-95.939,-95.939,-97.939,-99.939,-100.939,-102.939,-98.939,-49.939003,-10.939003,20.060997,33.060997,37.060997,24.060997,12.060997,-3.939003,-67.939,-97.939,-95.939,-93.939,-91.939,-93.939,-73.939,-44.939003,11.060997,29.060997,11.060997,8.060997,-6.939003,-49.939003,-62.939003,-62.939003,-56.939003,-51.939003,-49.939003,-47.939003,-45.939003,-43.939003,-42.939003,-41.939003,-39.939003,-36.939003,-32.939003,-31.939003,-31.939003,-30.939003,-31.939003,-33.939003,-32.939003,-31.939003,-32.939003,-29.939003,-28.939003,-31.939003,-34.939003,-37.939003,-34.939003,-32.939003,-30.939003,-32.939003,-35.939003,-35.939003,-36.939003,-38.939003,-50.939003,-54.939003,-51.939003,-51.939003,-53.939003,-54.939003,-59.939003,-65.939,-63.939003,-62.939003,-60.939003,-67.939,-60.939003,-16.939003,16.060997,42.060997,45.060997,22.060997,-27.939003,-59.939003,-81.939,-82.939,-83.939,-87.939,-87.939,-88.939,-88.939,-90.939,-92.939,-93.939,-94.939,-95.939,-95.939,-96.939,-97.939,-98.939,-98.939,-100.939,-101.939,-102.939,-101.939,-102.939,-102.939,-100.939,-98.939,-96.939,-93.939,-90.939,-89.939,-86.939,-82.939,-79.939,-76.939,-71.939,-68.939,-65.939,-69.939,-58.939003,-35.939003,-27.939003,-25.939003,-32.939003,-28.939003,-21.939003,-17.939003,-14.939003,-13.939003,-9.939003,-5.939003,-0.939003,2.
060997,5.060997,9.060997,15.060997,24.060997,27.060997,33.060997,48.060997,45.060997,35.060997,22.060997,18.060997,24.060997,21.060997,19.060997,21.060997,22.060997,24.060997,25.060997,26.060997,27.060997,27.060997,27.060997,29.060997,30.060997,31.060997,30.060997,31.060997,35.060997,37.060997,39.060997,38.060997,37.060997,36.060997,37.060997,38.060997,39.060997,35.060997,33.060997,37.060997,40.060997,42.060997,42.060997,42.060997,42.060997,42.060997,43.060997,43.060997,41.060997,36.060997,39.060997,41.060997,42.060997,41.060997,41.060997,41.060997,42.060997,43.060997,48.060997,52.060997,55.060997,56.060997,55.060997,51.060997,47.060997,43.060997,36.060997,34.060997,36.060997,38.060997,38.060997,36.060997,35.060997,34.060997,31.060997,30.060997,31.060997,30.060997,30.060997,31.060997,31.060997,31.060997,29.060997,28.060997,26.060997,28.060997,28.060997,27.060997,24.060997,22.060997,23.060997,23.060997,22.060997,21.060997,19.060997,16.060997,16.060997,17.060997,12.060997,8.060997,5.060997,6.060997,6.060997,7.060997,6.060997,3.060997,2.060997,1.060997,0.06099701,-2.939003,-7.939003,-13.939003,-0.939003,20.060997,30.060997,32.060997,25.060997,1.060997,-19.939003,-30.939003,-31.939003,-29.939003,-33.939003,-40.939003,-49.939003,-55.939003,-64.939,-77.939,-85.939,-91.939,-89.939,-90.939,-93.939,-93.939,-93.939,-94.939,-95.939,-96.939,-99.939,-99.939,-99.939,-97.939,-95.939,-95.939,-91.939,-86.939,-86.939,-85.939,-82.939,-81.939,-80.939,-80.939,-76.939,-70.939,-59.939003,-49.939003,-41.939003,-43.939003,-46.939003,-45.939003,-45.939003,-44.939003,-47.939003,-47.939003,-43.939003,-42.939003,-41.939003,-41.939003,-47.939003,-54.939003,-53.939003,-51.939003,-49.939003,-48.939003,-48.939003,-48.939003,-48.939003,-47.939003,-43.939003,-39.939003,-35.939003,-39.939003,-41.939003,-38.939003,-35.939003,-32.939003,-27.939003,-23.939003,-20.939003,-20.939003,-17.939003,-7.939003,-8.939003,-17.939003,-46.939003,-60.939003,-59.939003,-59.939003,-58.939003,-57.939003,-53.939003,-47.93
9003,-40.939003,-37.939003,-36.939003,-40.939003,-40.939003,-31.939003,-27.939003,-27.939003,-23.939003,-21.939003,-20.939003,-20.939003,-24.939003,-35.939003,-30.939003,-18.939003,-8.939003,-14.939003,-34.939003,-33.939003,-31.939003,-34.939003,-35.939003,-35.939003,-35.939003,-33.939003,-29.939003,-29.939003,-29.939003,-28.939003,-32.939003,-38.939003,-42.939003,-42.939003,-36.939003,-37.939003,-37.939003,-35.939003,-34.939003,-34.939003,-40.939003,-39.939003,-29.939003,-33.939003,-41.939003,-54.939003,-58.939003,-58.939003,-57.939003,-58.939003,-62.939003,-60.939003,-59.939003,-59.939003,-59.939003,-60.939003,-62.939003,-62.939003,-61.939003,-64.939,-64.939,-56.939003,-56.939003,-61.939003,-66.939,-68.939,-66.939,-68.939,-70.939,-69.939,-69.939,-70.939,-73.939,-74.939,-71.939,-72.939,-74.939,-75.939,-75.939,-73.939,-74.939,-75.939,-78.939,-54.939003,-27.939003,-2.939003,-5.939003,-22.939003,-61.939003,-82.939,-86.939,-84.939,-83.939,-86.939,-86.939,-86.939,-86.939,-88.939,-91.939,-93.939,-94.939,-95.939,-95.939,-96.939,-95.939,-96.939,-98.939,-100.939,-101.939,-101.939,-100.939,-99.939,-97.939,-95.939,-93.939,-95.939,-96.939,-93.939,-94.939,-95.939,-95.939,-95.939,-94.939,-92.939,-93.939,-96.939,-81.939,-55.939003,-8.939003,13.060997,10.060997,-37.939003,-78.939,-91.939,-92.939,-90.939,-90.939,-90.939,-91.939,-91.939,-91.939,-91.939,-91.939,-90.939,-87.939,-84.939,-83.939,-87.939,-90.939,-89.939,-88.939,-87.939,-86.939,-87.939,-88.939,-88.939,-88.939,-87.939,-78.939,-66.939,-71.939,-77.939,-86.939,-80.939,-76.939,-79.939,-78.939,-77.939,-79.939,-76.939,-68.939,-50.939003,-36.939003,-20.939003,-18.939003,-13.939003,-11.939003,-11.939003,-16.939003,-12.939003,-9.939003,-9.939003,-15.939003,-24.939003,-31.939003,-34.939003,-35.939003,-39.939003,-42.939003,-40.939003,-42.939003,-45.939003,-47.939003,-49.939003,-52.939003,-55.939003,-58.939003,-59.939003,-62.939003,-66.939,-76.939,-87.939,-99.939,-99.939,-97.939,-96.939,-93.939,-90.939,-89.939,-87.939,-85.939,-60.9390
03,-39.939003,-32.939003,-52.939003,-81.939,-80.939,-80.939,-81.939,-76.939,-73.939,-72.939,-71.939,-69.939,-66.939,-62.939003,-60.939003,-57.939003,-56.939003,-58.939003,-59.939003,-58.939003,-53.939003,-50.939003,-51.939003,-55.939003,-54.939003,-46.939003,-44.939003,-46.939003,-51.939003,-51.939003,-47.939003,-48.939003,-49.939003,-47.939003,-49.939003,-54.939003,-56.939003,-47.939003,-29.939003,4.060997,26.060997,20.060997,29.060997,37.060997,-18.939003,-61.939003,-87.939,-89.939,-89.939,-91.939,-93.939,-95.939,-94.939,-94.939,-97.939,-100.939,-102.939,-102.939,-101.939,-102.939,-103.939,-101.939,-97.939,-100.939,-102.939,-101.939,-101.939,-102.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-100.939,-98.939,-96.939,-96.939,-98.939,-95.939,-94.939,-93.939,-91.939,-89.939,-88.939,-87.939,-87.939,-84.939,-81.939,-77.939,-76.939,-75.939,-76.939,-75.939,-74.939,-72.939,-70.939,-66.939,-66.939,-67.939,-59.939003,-52.939003,-48.939003,-45.939003,-36.939003,-11.939003,-6.939003,-11.939003,-16.939003,-20.939003,-23.939003,-28.939003,-30.939003,-25.939003,-25.939003,-27.939003,-23.939003,-23.939003,-24.939003,-28.939003,-32.939003,-35.939003,-37.939003,-37.939003,-38.939003,-40.939003,-45.939003,-47.939003,-49.939003,-49.939003,-50.939003,-51.939003,-55.939003,-60.939003,-67.939,-71.939,-75.939,-78.939,-82.939,-87.939,-92.939,-96.939,-99.939,-100.939,-100.939,-100.939,-102.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-100.939,-99.939,-97.939,-95.939,-94.939,-95.939,-91.939,-87.939,-89.939,-82.939,-66.939,-43.939003,-22.939003,-8.939003,3.060997,15.060997,34.060997,38.060997,30.060997,26.060997,23.060997,22.060997,20.060997,18.060997,13.060997,14.060997,20.060997,0.06099701,-14.939003,-6.939003,-4.939003,-5.939003,-6.939003,-4.939003,0.06099701,-7.939003,-13.939003,-14.939003,-12.939003,-11.939003,-12.939003,-15.939003,-17.939003,-15.939003,-14.939003,-16.939003,-16.939003,-16.939003,-20.939003,-20.939003,-17.939003,-14.93
9003,-11.939003,-11.939003,-10.939003,-11.939003,-5.939003,-4.939003,-7.939003,-10.939003,-4.939003,16.060997,24.060997,26.060997,21.060997,1.060997,-35.939003,-60.939003,-77.939,-75.939,-80.939,-85.939,-89.939,-91.939,-92.939,-95.939,-98.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-101.939,-102.939,-103.939,-103.939,-102.939,-100.939,-96.939,-97.939,-98.939,-98.939,-96.939,-95.939,-91.939,-89.939,-88.939,-88.939,-87.939,-84.939,-86.939,-85.939,-44.939003,-8.939003,21.060997,34.060997,38.060997,25.060997,12.060997,-4.939003,-61.939003,-80.939,-63.939003,-60.939003,-60.939003,-61.939003,-53.939003,-40.939003,-17.939003,-10.939003,-18.939003,-20.939003,-24.939003,-33.939003,-37.939003,-39.939003,-32.939003,-30.939003,-30.939003,-29.939003,-28.939003,-28.939003,-28.939003,-30.939003,-29.939003,-29.939003,-26.939003,-27.939003,-29.939003,-31.939003,-34.939003,-36.939003,-35.939003,-37.939003,-41.939003,-40.939003,-40.939003,-44.939003,-48.939003,-51.939003,-51.939003,-48.939003,-42.939003,-48.939003,-56.939003,-57.939003,-56.939003,-57.939003,-71.939,-78.939,-79.939,-81.939,-84.939,-84.939,-90.939,-96.939,-94.939,-92.939,-90.939,-97.939,-82.939,-13.939003,39.060997,79.061,75.061,39.060997,-26.939003,-70.939,-99.939,-95.939,-94.939,-94.939,-93.939,-91.939,-89.939,-86.939,-85.939,-84.939,-84.939,-84.939,-82.939,-80.939,-78.939,-77.939,-76.939,-76.939,-76.939,-76.939,-73.939,-72.939,-72.939,-69.939,-67.939,-65.939,-61.939003,-58.939003,-59.939003,-57.939003,-52.939003,-51.939003,-49.939003,-42.939003,-41.939003,-42.939003,-46.939003,-43.939003,-31.939003,-26.939003,-24.939003,-27.939003,-24.939003,-19.939003,-16.939003,-14.939003,-12.939003,-8.939003,-5.939003,-1.939003,1.060997,5.060997,13.060997,19.060997,25.060997,31.060997,38.060997,51.060997,41.060997,23.060997,20.060997,21.060997,24.060997,22.060997,21.060997,24.060997,24.060997,24.060997,25.060997,26.060997,29.060997,29.060997,29.060997,31.060997,32.060997,32.060997,32.060997,33.060997,36.0609
97,38.060997,39.060997,38.060997,38.060997,38.060997,38.060997,38.060997,38.060997,36.060997,36.060997,38.060997,40.060997,43.060997,43.060997,44.060997,43.060997,44.060997,44.060997,43.060997,41.060997,39.060997,42.060997,43.060997,42.060997,42.060997,43.060997,44.060997,45.060997,45.060997,49.060997,53.060997,55.060997,57.060997,57.060997,52.060997,49.060997,45.060997,37.060997,35.060997,38.060997,38.060997,38.060997,36.060997,35.060997,33.060997,31.060997,31.060997,31.060997,31.060997,31.060997,32.060997,31.060997,32.060997,30.060997,28.060997,25.060997,28.060997,29.060997,27.060997,25.060997,23.060997,23.060997,23.060997,21.060997,19.060997,16.060997,13.060997,15.060997,17.060997,13.060997,10.060997,7.060997,7.060997,6.060997,6.060997,5.060997,5.060997,4.060997,3.060997,2.060997,-2.939003,-6.939003,-11.939003,-4.939003,7.060997,24.060997,28.060997,19.060997,0.06099701,-18.939003,-32.939003,-32.939003,-27.939003,-32.939003,-36.939003,-40.939003,-45.939003,-56.939003,-83.939,-91.939,-91.939,-87.939,-85.939,-85.939,-84.939,-83.939,-83.939,-83.939,-80.939,-79.939,-78.939,-78.939,-75.939,-74.939,-72.939,-67.939,-62.939003,-64.939,-64.939,-62.939003,-61.939003,-61.939003,-61.939003,-58.939003,-54.939003,-51.939003,-48.939003,-45.939003,-46.939003,-47.939003,-47.939003,-47.939003,-46.939003,-47.939003,-47.939003,-44.939003,-43.939003,-42.939003,-39.939003,-42.939003,-46.939003,-44.939003,-39.939003,-33.939003,-33.939003,-34.939003,-40.939003,-37.939003,-29.939003,-21.939003,-15.939003,-13.939003,-16.939003,-17.939003,-14.939003,-13.939003,-11.939003,-7.939003,-4.939003,-1.939003,-2.939003,0.06099701,10.060997,5.060997,-10.939003,-44.939003,-60.939003,-59.939003,-58.939003,-57.939003,-56.939003,-52.939003,-46.939003,-41.939003,-37.939003,-37.939003,-40.939003,-40.939003,-31.939003,-27.939003,-24.939003,-22.939003,-21.939003,-22.939003,-21.939003,-25.939003,-40.939003,-26.939003,-1.939003,19.060997,17.060997,-6.939003,-4.939003,-2.939003,-9.939003,-12.939003,-14.939003,-
9.939003,-6.939003,-1.939003,-3.939003,-5.939003,-6.939003,-16.939003,-29.939003,-32.939003,-28.939003,-14.939003,-15.939003,-16.939003,-12.939003,-14.939003,-17.939003,-25.939003,-23.939003,-12.939003,-25.939003,-41.939003,-58.939003,-64.939,-64.939,-61.939003,-62.939003,-66.939,-64.939,-62.939003,-61.939003,-61.939003,-63.939003,-64.939,-64.939,-63.939003,-65.939,-64.939,-57.939003,-57.939003,-61.939003,-64.939,-65.939,-63.939003,-65.939,-67.939,-64.939,-64.939,-65.939,-67.939,-67.939,-64.939,-64.939,-66.939,-67.939,-66.939,-64.939,-63.939003,-65.939,-68.939,-55.939003,-41.939003,-27.939003,-27.939003,-34.939003,-56.939003,-68.939,-72.939,-71.939,-70.939,-70.939,-71.939,-71.939,-72.939,-73.939,-74.939,-75.939,-76.939,-77.939,-79.939,-80.939,-79.939,-80.939,-81.939,-83.939,-85.939,-85.939,-85.939,-86.939,-86.939,-86.939,-86.939,-86.939,-86.939,-83.939,-84.939,-88.939,-89.939,-88.939,-84.939,-86.939,-88.939,-89.939,-76.939,-53.939003,-13.939003,5.060997,3.060997,-38.939003,-74.939,-86.939,-89.939,-87.939,-89.939,-91.939,-92.939,-93.939,-93.939,-93.939,-93.939,-94.939,-91.939,-89.939,-87.939,-91.939,-94.939,-96.939,-96.939,-96.939,-97.939,-98.939,-99.939,-97.939,-96.939,-98.939,-86.939,-72.939,-78.939,-86.939,-97.939,-93.939,-91.939,-94.939,-93.939,-90.939,-92.939,-86.939,-68.939,-42.939003,-23.939003,-17.939003,-17.939003,-15.939003,-6.939003,-4.939003,-9.939003,-11.939003,-14.939003,-19.939003,-24.939003,-29.939003,-32.939003,-32.939003,-30.939003,-37.939003,-43.939003,-43.939003,-44.939003,-47.939003,-53.939003,-56.939003,-54.939003,-55.939003,-59.939003,-68.939,-81.939,-95.939,-99.939,-102.939,-102.939,-96.939,-90.939,-81.939,-73.939,-65.939,-60.939003,-54.939003,-49.939003,-43.939003,-37.939003,-33.939003,-35.939003,-38.939003,-41.939003,-41.939003,-40.939003,-33.939003,-28.939003,-29.939003,-28.939003,-26.939003,-27.939003,-31.939003,-38.939003,-38.939003,-40.939003,-48.939003,-53.939003,-54.939003,-57.939003,-59.939003,-64.939,-72.939,-79.939,-80.939,-81.939,-
82.939,-87.939,-89.939,-88.939,-93.939,-95.939,-92.939,-94.939,-99.939,-99.939,-77.939,-34.939003,37.060997,83.061,62.060997,68.061,76.061,-5.939003,-64.939,-98.939,-100.939,-99.939,-100.939,-100.939,-99.939,-101.939,-101.939,-100.939,-102.939,-103.939,-101.939,-100.939,-99.939,-102.939,-103.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-100.939,-97.939,-95.939,-95.939,-95.939,-93.939,-91.939,-87.939,-80.939,-75.939,-72.939,-66.939,-60.939003,-58.939003,-56.939003,-54.939003,-47.939003,-37.939003,-25.939003,-21.939003,-20.939003,-21.939003,-20.939003,-17.939003,-11.939003,-6.939003,-4.939003,-7.939003,-12.939003,-11.939003,-8.939003,-7.939003,-12.939003,-19.939003,-26.939003,-27.939003,-25.939003,-19.939003,-17.939003,-20.939003,-33.939003,-45.939003,-50.939003,-60.939003,-69.939,-68.939,-70.939,-75.939,-81.939,-87.939,-90.939,-94.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-99.939,-99.939,-100.939,-100.939,-100.939,-101.939,-101.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-103.939,-103.939,-102.939,-101.939,-101.939,-101.939,-99.939,-96.939,-94.939,-90.939,-85.939,-80.939,-77.939,-78.939,-71.939,-63.939003,-64.939,-59.939003,-47.939003,-38.939003,-27.939003,-15.939003,-15.939003,-19.939003,-26.939003,-30.939003,-33.939003,-34.939003,-35.939003,-36.939003,-40.939003,-43.939003,-43.939003,-44.939003,-44.939003,-37.939003,-30.939003,-22.939003,-27.939003,-35.939003,-27.939003,-23.939003,-24.939003,-20.939003,-14.939003,-6.939003,-3.939003,-1.939003,6.060997,13.060997,19.060997,23.060997,27.060997,30.060997,37.060997,43.060997,41.060997,42.060997,45.060997,20.060997,11.060997,39.060997,54.060997,60.060997,57.060997,52.060997,47.060997,8.060997,-13.939003,13.060997,22.060997,23.060997,20.060997,2.060997,-32.939003,-72.939,-101.939,-100.939,-101.939,-100.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-100.
939,-98.939,-101.939,-102.939,-102.939,-99.939,-95.939,-90.939,-89.939,-89.939,-87.939,-83.939,-78.939,-68.939,-64.939,-63.939003,-59.939003,-54.939003,-46.939003,-51.939003,-59.939003,-29.939003,-3.939003,20.060997,33.060997,37.060997,25.060997,8.060997,-11.939003,-52.939003,-50.939003,-7.939003,-5.939003,-9.939003,-7.939003,-16.939003,-28.939003,-40.939003,-44.939003,-40.939003,-34.939003,-28.939003,-20.939003,-24.939003,-31.939003,-29.939003,-33.939003,-41.939003,-44.939003,-46.939003,-50.939003,-54.939003,-61.939003,-65.939,-70.939,-73.939,-76.939,-80.939,-85.939,-89.939,-91.939,-90.939,-92.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-89.939,-69.939,-83.939,-97.939,-99.939,-93.939,-88.939,-96.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-101.939,-100.939,-95.939,-100.939,-83.939,-18.939003,30.060997,66.061,52.060997,20.060997,-30.939003,-68.939,-93.939,-84.939,-79.939,-76.939,-73.939,-68.939,-61.939003,-53.939003,-49.939003,-49.939003,-49.939003,-48.939003,-41.939003,-35.939003,-27.939003,-24.939003,-22.939003,-23.939003,-22.939003,-21.939003,-18.939003,-15.939003,-11.939003,-9.939003,-8.939003,-8.939003,-6.939003,-6.939003,-13.939003,-15.939003,-13.939003,-16.939003,-18.939003,-13.939003,-20.939003,-30.939003,-33.939003,-36.939003,-36.939003,-27.939003,-19.939003,-18.939003,-17.939003,-17.939003,-17.939003,-15.939003,-10.939003,-4.939003,-0.939003,1.060997,2.060997,5.060997,16.060997,24.060997,29.060997,38.060997,45.060997,45.060997,28.060997,9.060997,17.060997,20.060997,19.060997,20.060997,22.060997,28.060997,29.060997,27.060997,26.060997,27.060997,30.060997,30.060997,31.060997,33.060997,34.060997,33.060997,34.060997,36.060997,38.060997,39.060997,39.060997,38.060997,39.060997,41.060997,41.060997,39.060997,36.060997,39.060997,41.060997,39.060997,41.060997,44.060997,44.060997,45.060997,46.060997,45.060997,44.060997,42.060997,42.060997,43.060997,45.060997,44.060997,41.060997,42.060997,45.060997,49.060997,50.060997,48.060997,5
2.060997,54.060997,55.060997,57.060997,58.060997,53.060997,52.060997,50.060997,42.060997,39.060997,38.060997,37.060997,36.060997,36.060997,34.060997,32.060997,34.060997,34.060997,33.060997,34.060997,34.060997,33.060997,32.060997,32.060997,29.060997,27.060997,24.060997,28.060997,30.060997,28.060997,27.060997,26.060997,23.060997,20.060997,18.060997,16.060997,13.060997,10.060997,12.060997,16.060997,13.060997,11.060997,8.060997,6.060997,5.060997,3.060997,4.060997,6.060997,5.060997,3.060997,1.060997,0.06099701,-1.939003,-5.939003,-5.939003,-4.939003,3.060997,5.060997,0.06099701,-10.939003,-22.939003,-35.939003,-37.939003,-34.939003,-35.939003,-33.939003,-30.939003,-35.939003,-45.939003,-66.939,-70.939,-67.939,-55.939003,-50.939003,-50.939003,-46.939003,-43.939003,-44.939003,-44.939003,-41.939003,-33.939003,-30.939003,-31.939003,-34.939003,-37.939003,-33.939003,-30.939003,-27.939003,-34.939003,-39.939003,-41.939003,-41.939003,-42.939003,-43.939003,-46.939003,-49.939003,-48.939003,-49.939003,-50.939003,-50.939003,-48.939003,-46.939003,-44.939003,-44.939003,-58.939003,-49.939003,-17.939003,-4.939003,-2.939003,-20.939003,-15.939003,-1.939003,0.06099701,4.060997,11.060997,9.060997,0.06099701,-20.939003,-8.939003,15.060997,25.060997,25.060997,15.060997,11.060997,7.060997,6.060997,2.060997,0.06099701,-2.939003,-3.939003,-4.939003,-10.939003,-14.939003,-12.939003,-18.939003,-29.939003,-39.939003,-43.939003,-42.939003,-39.939003,-36.939003,-35.939003,-35.939003,-36.939003,-38.939003,-38.939003,-36.939003,-38.939003,-38.939003,-35.939003,-30.939003,-25.939003,-23.939003,-22.939003,-22.939003,-22.939003,-26.939003,-41.939003,-21.939003,13.060997,47.060997,52.060997,26.060997,33.060997,37.060997,23.060997,20.060997,22.060997,34.060997,40.060997,43.060997,44.060997,43.060997,37.060997,14.060997,-12.939003,-11.939003,8.060997,47.060997,43.060997,39.060997,46.060997,42.060997,35.060997,31.060997,37.060997,52.060997,9.060997,-40.939003,-85.939,-99.939,-98.939,-92.939,-90.939,-93.939,-92
.939,-92.939,-92.939,-95.939,-99.939,-99.939,-99.939,-99.939,-98.939,-90.939,-69.939,-72.939,-85.939,-89.939,-91.939,-90.939,-90.939,-89.939,-84.939,-84.939,-85.939,-82.939,-79.939,-76.939,-75.939,-75.939,-77.939,-75.939,-72.939,-68.939,-68.939,-70.939,-55.939003,-40.939003,-31.939003,-29.939003,-33.939003,-49.939003,-58.939003,-60.939003,-63.939003,-63.939003,-55.939003,-55.939003,-57.939003,-59.939003,-58.939003,-53.939003,-49.939003,-47.939003,-48.939003,-52.939003,-55.939003,-54.939003,-53.939003,-50.939003,-50.939003,-50.939003,-52.939003,-53.939003,-54.939003,-57.939003,-61.939003,-63.939003,-61.939003,-60.939003,-55.939003,-57.939003,-61.939003,-63.939003,-60.939003,-51.939003,-57.939003,-63.939003,-64.939,-60.939003,-53.939003,-48.939003,-43.939003,-39.939003,-47.939003,-55.939003,-61.939003,-63.939003,-62.939003,-65.939,-68.939,-71.939,-72.939,-72.939,-74.939,-74.939,-75.939,-72.939,-70.939,-69.939,-72.939,-76.939,-81.939,-85.939,-87.939,-86.939,-87.939,-91.939,-88.939,-87.939,-91.939,-81.939,-70.939,-75.939,-83.939,-93.939,-96.939,-98.939,-99.939,-100.939,-99.939,-101.939,-89.939,-61.939003,-33.939003,-13.939003,-16.939003,-16.939003,-14.939003,-10.939003,-10.939003,-15.939003,-21.939003,-27.939003,-30.939003,-33.939003,-36.939003,-38.939003,-38.939003,-33.939003,-39.939003,-44.939003,-45.939003,-46.939003,-49.939003,-52.939003,-56.939003,-61.939003,-66.939,-71.939,-79.939,-86.939,-93.939,-93.939,-90.939,-84.939,-73.939,-64.939,-59.939003,-55.939003,-52.939003,-54.939003,-54.939003,-53.939003,-46.939003,-38.939003,-27.939003,-37.939003,-54.939003,-60.939003,-63.939003,-64.939,-62.939003,-61.939003,-64.939,-65.939,-65.939,-64.939,-67.939,-73.939,-71.939,-71.939,-76.939,-79.939,-81.939,-83.939,-84.939,-85.939,-90.939,-93.939,-94.939,-94.939,-93.939,-94.939,-94.939,-92.939,-95.939,-97.939,-97.939,-97.939,-97.939,-100.939,-84.939,-50.939003,33.060997,88.061,67.061,75.061,85.061,7.060997,-53.939003,-98.939,-102.939,-102.939,-102.939,-102.939,-100.939,-101.939,-
99.939,-95.939,-94.939,-92.939,-90.939,-88.939,-86.939,-86.939,-82.939,-76.939,-73.939,-70.939,-69.939,-66.939,-65.939,-62.939003,-60.939003,-56.939003,-51.939003,-47.939003,-47.939003,-47.939003,-46.939003,-43.939003,-40.939003,-36.939003,-35.939003,-34.939003,-37.939003,-36.939003,-34.939003,-31.939003,-32.939003,-37.939003,-40.939003,-40.939003,-35.939003,-34.939003,-35.939003,-39.939003,-41.939003,-42.939003,-44.939003,-46.939003,-45.939003,-47.939003,-49.939003,-55.939003,-57.939003,-56.939003,-61.939003,-47.939003,8.060997,26.060997,30.060997,44.060997,33.060997,-1.939003,-45.939003,-78.939,-81.939,-84.939,-87.939,-89.939,-90.939,-92.939,-95.939,-98.939,-100.939,-102.939,-103.939,-103.939,-102.939,-101.939,-102.939,-102.939,-103.939,-103.939,-102.939,-103.939,-102.939,-102.939,-101.939,-98.939,-95.939,-92.939,-90.939,-88.939,-87.939,-84.939,-82.939,-78.939,-74.939,-71.939,-70.939,-69.939,-66.939,-61.939003,-57.939003,-53.939003,-51.939003,-48.939003,-45.939003,-45.939003,-42.939003,-37.939003,-35.939003,-34.939003,-39.939003,-36.939003,-32.939003,-33.939003,-32.939003,-26.939003,-28.939003,-27.939003,-18.939003,-15.939003,-13.939003,-12.939003,-9.939003,-7.939003,-6.939003,-6.939003,-4.939003,-2.939003,-0.939003,1.060997,8.060997,19.060997,6.060997,-3.939003,0.06099701,3.060997,7.060997,16.060997,24.060997,32.060997,30.060997,29.060997,32.060997,33.060997,34.060997,40.060997,44.060997,48.060997,48.060997,49.060997,50.060997,53.060997,55.060997,56.060997,58.060997,62.060997,29.060997,14.060997,51.060997,62.060997,62.060997,61.060997,59.060997,56.060997,13.060997,-12.939003,12.060997,22.060997,25.060997,22.060997,5.060997,-25.939003,-67.939,-97.939,-90.939,-87.939,-85.939,-84.939,-82.939,-80.939,-74.939,-70.939,-68.939,-68.939,-67.939,-65.939,-62.939003,-61.939003,-56.939003,-52.939003,-52.939003,-50.939003,-48.939003,-46.939003,-43.939003,-39.939003,-38.939003,-38.939003,-40.939003,-36.939003,-31.939003,-31.939003,-31.939003,-30.939003,-27.939003,-26.939003,-28
.939003,-38.939003,-47.939003,-26.939003,-5.939003,17.060997,30.060997,34.060997,24.060997,13.060997,-1.939003,-53.939003,-69.939,-52.939003,-47.939003,-47.939003,-55.939003,-54.939003,-44.939003,-3.939003,12.060997,4.060997,11.060997,2.060997,-49.939003,-68.939,-71.939,-71.939,-72.939,-76.939,-78.939,-79.939,-81.939,-83.939,-86.939,-88.939,-90.939,-92.939,-93.939,-95.939,-97.939,-99.939,-100.939,-100.939,-101.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-90.939,-67.939,-77.939,-86.939,-86.939,-77.939,-67.939,-79.939,-83.939,-77.939,-71.939,-68.939,-66.939,-66.939,-65.939,-60.939003,-55.939003,-49.939003,-48.939003,-42.939003,-26.939003,-14.939003,-5.939003,-14.939003,-22.939003,-30.939003,-37.939003,-40.939003,-35.939003,-32.939003,-31.939003,-34.939003,-35.939003,-32.939003,-26.939003,-24.939003,-29.939003,-31.939003,-34.939003,-35.939003,-36.939003,-37.939003,-34.939003,-32.939003,-35.939003,-38.939003,-40.939003,-44.939003,-46.939003,-45.939003,-48.939003,-50.939003,-51.939003,-51.939003,-52.939003,-56.939003,-59.939003,-61.939003,-64.939,-64.939,-62.939003,-66.939,-68.939,-49.939003,-38.939003,-33.939003,-26.939003,-21.939003,-20.939003,-17.939003,-14.939003,-13.939003,-10.939003,-7.939003,-3.939003,-2.939003,-1.939003,0.06099701,5.060997,16.060997,23.060997,26.060997,35.060997,39.060997,32.060997,22.060997,13.060997,19.060997,21.060997,20.060997,19.060997,21.060997,27.060997,30.060997,30.060997,27.060997,26.060997,29.060997,30.060997,32.060997,32.060997,33.060997,33.060997,35.060997,37.060997,38.060997,37.060997,37.060997,36.060997,38.060997,40.060997,40.060997,38.060997,35.060997,38.060997,40.060997,38.060997,39.060997,41.060997,43.060997,44.060997,44.060997,43.060997,43.060997,43.060997,44.060997,45.060997,44.060997,42.060997,39.060997,43.060997,46.060997,47.060997,47.060997,47.060997,52.060997,55.060997,56.060997,56.060997,55.060997,50.060997,49.060997,47.060997,41.060997,37.060997,37.060997,35.060997,34.060997,35.060997,34.060997,33.0
60997,33.060997,33.060997,32.060997,33.060997,33.060997,31.060997,29.060997,28.060997,27.060997,27.060997,27.060997,29.060997,29.060997,26.060997,25.060997,24.060997,21.060997,19.060997,18.060997,15.060997,12.060997,11.060997,11.060997,13.060997,11.060997,9.060997,7.060997,4.060997,3.060997,1.060997,2.060997,3.060997,2.060997,-0.939003,-4.939003,-3.939003,-3.939003,-6.939003,-5.939003,-4.939003,-4.939003,-2.939003,-0.939003,-8.939003,-16.939003,-25.939003,-29.939003,-30.939003,-29.939003,-31.939003,-38.939003,-41.939003,-48.939003,-65.939,-74.939,-75.939,-54.939003,-46.939003,-51.939003,-53.939003,-54.939003,-55.939003,-57.939003,-58.939003,-57.939003,-58.939003,-60.939003,-63.939003,-66.939,-69.939,-69.939,-68.939,-70.939,-72.939,-75.939,-75.939,-75.939,-76.939,-78.939,-78.939,-62.939003,-51.939003,-44.939003,-44.939003,-45.939003,-43.939003,-42.939003,-42.939003,-61.939003,-48.939003,-6.939003,1.060997,-3.939003,-28.939003,-25.939003,-14.939003,-14.939003,-12.939003,-11.939003,-13.939003,-19.939003,-34.939003,-29.939003,-18.939003,-20.939003,-24.939003,-32.939003,-34.939003,-36.939003,-35.939003,-36.939003,-37.939003,-40.939003,-42.939003,-43.939003,-41.939003,-40.939003,-36.939003,-35.939003,-35.939003,-37.939003,-36.939003,-33.939003,-34.939003,-35.939003,-34.939003,-34.939003,-34.939003,-36.939003,-36.939003,-35.939003,-36.939003,-35.939003,-32.939003,-30.939003,-28.939003,-25.939003,-24.939003,-25.939003,-25.939003,-27.939003,-36.939003,-31.939003,-16.939003,-1.939003,-1.939003,-16.939003,-9.939003,-6.939003,-16.939003,-17.939003,-13.939003,-5.939003,-2.939003,-3.939003,0.06099701,4.060997,4.060997,-8.939003,-25.939003,-22.939003,-9.939003,12.060997,10.060997,7.060997,12.060997,13.060997,12.060997,11.060997,17.060997,30.060997,-5.939003,-47.939003,-84.939,-93.939,-90.939,-88.939,-88.939,-92.939,-92.939,-93.939,-94.939,-97.939,-100.939,-101.939,-102.939,-103.939,-102.939,-94.939,-69.939,-76.939,-96.939,-98.939,-99.939,-99.939,-98.939,-98.939,-96.939,-96.939,-96
.939,-95.939,-93.939,-92.939,-91.939,-91.939,-92.939,-91.939,-90.939,-88.939,-88.939,-89.939,-43.939003,-1.939003,14.060997,0.06099701,-28.939003,-62.939003,-81.939,-84.939,-85.939,-85.939,-81.939,-81.939,-82.939,-83.939,-82.939,-80.939,-78.939,-77.939,-78.939,-79.939,-81.939,-81.939,-79.939,-76.939,-76.939,-75.939,-74.939,-76.939,-79.939,-80.939,-82.939,-81.939,-78.939,-76.939,-74.939,-76.939,-78.939,-76.939,-73.939,-69.939,-70.939,-73.939,-77.939,-64.939,-46.939003,-33.939003,-28.939003,-32.939003,-49.939003,-61.939003,-63.939003,-64.939,-66.939,-66.939,-68.939,-68.939,-68.939,-67.939,-67.939,-66.939,-66.939,-62.939003,-60.939003,-59.939003,-61.939003,-63.939003,-65.939,-67.939,-69.939,-66.939,-67.939,-70.939,-67.939,-65.939,-67.939,-65.939,-62.939003,-62.939003,-64.939,-68.939,-71.939,-73.939,-75.939,-75.939,-74.939,-76.939,-72.939,-60.939003,-47.939003,-37.939003,-14.939003,-16.939003,-17.939003,-16.939003,-19.939003,-24.939003,-31.939003,-38.939003,-39.939003,-41.939003,-41.939003,-44.939003,-44.939003,-39.939003,-42.939003,-45.939003,-47.939003,-49.939003,-52.939003,-55.939003,-63.939003,-73.939,-80.939,-85.939,-88.939,-90.939,-92.939,-89.939,-83.939,-74.939,-63.939003,-55.939003,-53.939003,-53.939003,-55.939003,-61.939003,-65.939,-67.939,-57.939003,-42.939003,-22.939003,-41.939003,-74.939,-82.939,-87.939,-89.939,-90.939,-91.939,-95.939,-98.939,-99.939,-97.939,-97.939,-101.939,-99.939,-97.939,-99.939,-100.939,-102.939,-103.939,-102.939,-101.939,-102.939,-102.939,-103.939,-101.939,-98.939,-98.939,-96.939,-93.939,-94.939,-95.939,-96.939,-94.939,-91.939,-95.939,-85.939,-61.939003,20.060997,76.061,58.060997,63.060997,71.061,9.060997,-43.939003,-85.939,-89.939,-87.939,-86.939,-85.939,-85.939,-83.939,-79.939,-75.939,-70.939,-66.939,-65.939,-64.939,-62.939003,-60.939003,-55.939003,-47.939003,-43.939003,-40.939003,-39.939003,-36.939003,-34.939003,-32.939003,-30.939003,-25.939003,-21.939003,-19.939003,-21.939003,-21.939003,-21.939003,-19.939003,-18.939003,-15.939003,-1
7.939003,-20.939003,-29.939003,-31.939003,-32.939003,-30.939003,-32.939003,-41.939003,-49.939003,-56.939003,-57.939003,-58.939003,-60.939003,-66.939,-70.939,-74.939,-80.939,-86.939,-86.939,-86.939,-86.939,-95.939,-99.939,-99.939,-101.939,-72.939,34.060997,69.061,76.061,92.061,72.061,17.060997,-52.939003,-103.939,-101.939,-99.939,-98.939,-99.939,-99.939,-97.939,-98.939,-98.939,-97.939,-96.939,-94.939,-93.939,-91.939,-89.939,-89.939,-89.939,-89.939,-87.939,-85.939,-85.939,-83.939,-82.939,-81.939,-78.939,-73.939,-68.939,-64.939,-62.939003,-59.939003,-57.939003,-55.939003,-51.939003,-45.939003,-42.939003,-40.939003,-39.939003,-36.939003,-32.939003,-27.939003,-22.939003,-20.939003,-19.939003,-18.939003,-20.939003,-19.939003,-16.939003,-17.939003,-18.939003,-26.939003,-27.939003,-26.939003,-28.939003,-28.939003,-27.939003,-28.939003,-26.939003,-20.939003,-11.939003,-0.939003,15.060997,23.060997,25.060997,27.060997,28.060997,32.060997,36.060997,41.060997,43.060997,55.060997,74.061,45.060997,20.060997,18.060997,28.060997,43.060997,51.060997,63.060997,80.061,71.061,62.060997,59.060997,60.060997,61.060997,62.060997,62.060997,63.060997,61.060997,59.060997,56.060997,54.060997,53.060997,54.060997,56.060997,60.060997,26.060997,9.060997,43.060997,49.060997,43.060997,43.060997,41.060997,38.060997,6.060997,-10.939003,12.060997,23.060997,26.060997,23.060997,8.060997,-20.939003,-57.939003,-79.939,-64.939,-59.939003,-58.939003,-56.939003,-53.939003,-52.939003,-44.939003,-39.939003,-37.939003,-37.939003,-37.939003,-34.939003,-32.939003,-31.939003,-26.939003,-23.939003,-22.939003,-20.939003,-17.939003,-18.939003,-18.939003,-16.939003,-16.939003,-17.939003,-21.939003,-19.939003,-16.939003,-22.939003,-25.939003,-24.939003,-23.939003,-25.939003,-34.939003,-44.939003,-51.939003,-32.939003,-9.939003,17.060997,28.060997,32.060997,24.060997,17.060997,6.060997,-55.939003,-88.939,-93.939,-87.939,-85.939,-97.939,-88.939,-61.939003,25.060997,59.060997,39.060997,46.060997,24.060997,-73.939,-103.939,
-103.939,-102.939,-101.939,-100.939,-99.939,-98.939,-98.939,-97.939,-96.939,-95.939,-95.939,-94.939,-93.939,-92.939,-91.939,-90.939,-89.939,-89.939,-88.939,-87.939,-86.939,-86.939,-84.939,-84.939,-84.939,-81.939,-72.939,-54.939003,-57.939003,-61.939003,-59.939003,-51.939003,-42.939003,-54.939003,-56.939003,-48.939003,-42.939003,-37.939003,-36.939003,-35.939003,-35.939003,-29.939003,-24.939003,-18.939003,-17.939003,-19.939003,-31.939003,-39.939003,-44.939003,-44.939003,-40.939003,-29.939003,-24.939003,-19.939003,-17.939003,-16.939003,-17.939003,-24.939003,-28.939003,-29.939003,-26.939003,-25.939003,-31.939003,-36.939003,-40.939003,-46.939003,-52.939003,-58.939003,-56.939003,-55.939003,-59.939003,-63.939003,-67.939,-75.939,-79.939,-81.939,-87.939,-91.939,-92.939,-92.939,-94.939,-96.939,-98.939,-102.939,-103.939,-102.939,-102.939,-99.939,-91.939,-58.939003,-37.939003,-29.939003,-25.939003,-21.939003,-19.939003,-13.939003,-7.939003,-6.939003,-3.939003,-0.939003,0.06099701,-0.939003,-2.939003,1.060997,7.060997,18.060997,24.060997,25.060997,31.060997,33.060997,24.060997,21.060997,20.060997,22.060997,24.060997,23.060997,22.060997,23.060997,28.060997,31.060997,32.060997,28.060997,27.060997,30.060997,31.060997,32.060997,31.060997,32.060997,34.060997,36.060997,37.060997,37.060997,36.060997,35.060997,35.060997,37.060997,38.060997,40.060997,39.060997,37.060997,38.060997,40.060997,39.060997,39.060997,39.060997,42.060997,44.060997,44.060997,43.060997,43.060997,45.060997,45.060997,46.060997,43.060997,42.060997,40.060997,44.060997,46.060997,44.060997,45.060997,47.060997,51.060997,54.060997,55.060997,54.060997,52.060997,48.060997,48.060997,47.060997,41.060997,38.060997,38.060997,36.060997,35.060997,37.060997,36.060997,35.060997,34.060997,33.060997,33.060997,35.060997,35.060997,32.060997,30.060997,28.060997,29.060997,30.060997,34.060997,32.060997,31.060997,27.060997,27.060997,26.060997,25.060997,23.060997,22.060997,19.060997,17.060997,16.060997,15.060997,16.060997,14.060997,13.060997
,12.060997,9.060997,6.060997,5.060997,5.060997,5.060997,4.060997,1.060997,-2.939003,-3.939003,-2.939003,-4.939003,-4.939003,-4.939003,-8.939003,-8.939003,-3.939003,-10.939003,-18.939003,-25.939003,-31.939003,-35.939003,-33.939003,-38.939003,-49.939003,-50.939003,-55.939003,-71.939,-83.939,-87.939,-65.939,-57.939003,-63.939003,-69.939,-73.939,-74.939,-76.939,-79.939,-84.939,-88.939,-89.939,-91.939,-93.939,-98.939,-99.939,-100.939,-96.939,-95.939,-97.939,-96.939,-95.939,-95.939,-94.939,-91.939,-68.939,-51.939003,-40.939003,-41.939003,-42.939003,-41.939003,-41.939003,-42.939003,-58.939003,-48.939003,-12.939003,-10.939003,-17.939003,-39.939003,-40.939003,-34.939003,-34.939003,-34.939003,-36.939003,-37.939003,-39.939003,-46.939003,-47.939003,-46.939003,-54.939003,-59.939003,-64.939,-64.939,-63.939003,-60.939003,-58.939003,-56.939003,-57.939003,-58.939003,-57.939003,-50.939003,-44.939003,-37.939003,-36.939003,-37.939003,-39.939003,-37.939003,-33.939003,-36.939003,-40.939003,-38.939003,-38.939003,-38.939003,-36.939003,-35.939003,-35.939003,-34.939003,-33.939003,-29.939003,-29.939003,-30.939003,-26.939003,-25.939003,-28.939003,-28.939003,-29.939003,-35.939003,-36.939003,-33.939003,-35.939003,-38.939003,-44.939003,-40.939003,-39.939003,-45.939003,-46.939003,-43.939003,-38.939003,-37.939003,-41.939003,-36.939003,-31.939003,-25.939003,-29.939003,-38.939003,-35.939003,-30.939003,-23.939003,-24.939003,-25.939003,-24.939003,-20.939003,-16.939003,-17.939003,-12.939003,-1.939003,-24.939003,-50.939003,-75.939,-78.939,-74.939,-75.939,-77.939,-80.939,-81.939,-83.939,-84.939,-86.939,-88.939,-89.939,-91.939,-92.939,-92.939,-86.939,-65.939,-72.939,-92.939,-92.939,-93.939,-93.939,-93.939,-92.939,-92.939,-93.939,-93.939,-93.939,-93.939,-94.939,-93.939,-93.939,-93.939,-94.939,-95.939,-94.939,-94.939,-94.939,-35.939003,16.060997,36.060997,12.060997,-28.939003,-71.939,-94.939,-96.939,-97.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.9
39,-98.939,-96.939,-96.939,-96.939,-94.939,-96.939,-100.939,-100.939,-100.939,-98.939,-93.939,-92.939,-94.939,-95.939,-95.939,-91.939,-89.939,-88.939,-87.939,-87.939,-91.939,-69.939,-37.939003,-12.939003,-5.939003,-17.939003,-50.939003,-75.939,-74.939,-75.939,-76.939,-75.939,-75.939,-75.939,-73.939,-71.939,-70.939,-70.939,-68.939,-64.939,-62.939003,-62.939003,-63.939003,-64.939,-62.939003,-62.939003,-64.939,-62.939003,-62.939003,-64.939,-62.939003,-59.939003,-58.939003,-58.939003,-59.939003,-58.939003,-58.939003,-58.939003,-60.939003,-61.939003,-63.939003,-62.939003,-61.939003,-63.939003,-62.939003,-58.939003,-54.939003,-50.939003,-13.939003,-18.939003,-25.939003,-25.939003,-28.939003,-35.939003,-40.939003,-44.939003,-48.939003,-48.939003,-46.939003,-49.939003,-49.939003,-47.939003,-46.939003,-44.939003,-46.939003,-51.939003,-59.939003,-73.939,-84.939,-95.939,-96.939,-96.939,-97.939,-97.939,-98.939,-97.939,-93.939,-89.939,-88.939,-89.939,-88.939,-88.939,-90.939,-92.939,-93.939,-95.939,-80.939,-57.939003,-18.939003,-43.939003,-91.939,-97.939,-99.939,-100.939,-100.939,-100.939,-99.939,-100.939,-102.939,-98.939,-98.939,-100.939,-100.939,-102.939,-100.939,-101.939,-101.939,-102.939,-103.939,-103.939,-101.939,-101.939,-102.939,-102.939,-100.939,-99.939,-99.939,-98.939,-96.939,-93.939,-92.939,-89.939,-86.939,-87.939,-78.939,-61.939003,-2.939003,36.060997,23.060997,17.060997,11.060997,-15.939003,-34.939003,-44.939003,-40.939003,-36.939003,-29.939003,-29.939003,-30.939003,-24.939003,-21.939003,-20.939003,-15.939003,-11.939003,-10.939003,-12.939003,-17.939003,-18.939003,-18.939003,-18.939003,-23.939003,-27.939003,-29.939003,-29.939003,-30.939003,-37.939003,-42.939003,-45.939003,-47.939003,-50.939003,-59.939003,-62.939003,-63.939003,-67.939,-70.939,-72.939,-75.939,-79.939,-84.939,-86.939,-87.939,-86.939,-87.939,-89.939,-91.939,-92.939,-92.939,-93.939,-93.939,-94.939,-95.939,-96.939,-98.939,-99.939,-99.939,-99.939,-99.939,-100.939,-101.939,-102.939,-103.939,-73.939,25.060997,6
4.061,77.061,88.061,71.061,28.060997,-47.939003,-101.939,-98.939,-97.939,-96.939,-91.939,-87.939,-86.939,-84.939,-81.939,-76.939,-71.939,-66.939,-61.939003,-56.939003,-50.939003,-49.939003,-48.939003,-41.939003,-34.939003,-27.939003,-24.939003,-22.939003,-20.939003,-20.939003,-19.939003,-18.939003,-16.939003,-13.939003,-10.939003,-11.939003,-14.939003,-21.939003,-26.939003,-27.939003,-25.939003,-24.939003,-29.939003,-35.939003,-40.939003,-42.939003,-45.939003,-49.939003,-52.939003,-56.939003,-64.939,-68.939,-68.939,-68.939,-70.939,-75.939,-80.939,-85.939,-86.939,-83.939,-79.939,-50.939003,-28.939003,-25.939003,-6.939003,20.060997,53.060997,65.061,54.060997,55.060997,56.060997,55.060997,55.060997,53.060997,55.060997,61.060997,73.061,51.060997,29.060997,15.060997,26.060997,45.060997,50.060997,62.060997,81.061,69.061,57.060997,53.060997,55.060997,58.060997,53.060997,50.060997,48.060997,47.060997,44.060997,36.060997,32.060997,29.060997,25.060997,23.060997,25.060997,4.060997,-8.939003,-0.939003,-1.939003,-5.939003,-11.939003,-19.939003,-26.939003,-21.939003,-9.939003,17.060997,25.060997,26.060997,24.060997,6.060997,-25.939003,-36.939003,-34.939003,-9.939003,-7.939003,-13.939003,-13.939003,-14.939003,-17.939003,-19.939003,-21.939003,-24.939003,-26.939003,-28.939003,-32.939003,-36.939003,-39.939003,-42.939003,-46.939003,-50.939003,-52.939003,-55.939003,-62.939003,-67.939,-71.939,-73.939,-75.939,-77.939,-80.939,-82.939,-84.939,-84.939,-85.939,-84.939,-85.939,-87.939,-89.939,-88.939,-53.939003,-16.939003,24.060997,33.060997,34.060997,25.060997,19.060997,7.060997,-59.939003,-95.939,-100.939,-99.939,-99.939,-102.939,-93.939,-72.939,24.060997,60.060997,36.060997,39.060997,16.060997,-73.939,-101.939,-102.939,-98.939,-93.939,-89.939,-86.939,-83.939,-81.939,-78.939,-74.939,-70.939,-66.939,-64.939,-60.939003,-56.939003,-52.939003,-48.939003,-44.939003,-40.939003,-37.939003,-34.939003,-31.939003,-27.939003,-23.939003,-23.939003,-25.939003,-21.939003,-20.939003,-23.939003,-19.939003,
-16.939003,-14.939003,-16.939003,-20.939003,-20.939003,-21.939003,-20.939003,-23.939003,-26.939003,-27.939003,-27.939003,-26.939003,-34.939003,-37.939003,-37.939003,-46.939003,-48.939003,-25.939003,-5.939003,12.060997,20.060997,6.060997,-28.939003,-61.939003,-83.939,-81.939,-81.939,-83.939,-85.939,-86.939,-86.939,-85.939,-85.939,-86.939,-87.939,-89.939,-90.939,-91.939,-93.939,-92.939,-92.939,-93.939,-94.939,-95.939,-96.939,-98.939,-98.939,-99.939,-100.939,-100.939,-101.939,-101.939,-101.939,-102.939,-103.939,-101.939,-101.939,-101.939,-87.939,-66.939,-43.939003,-29.939003,-26.939003,-20.939003,-14.939003,-10.939003,-2.939003,5.060997,6.060997,9.060997,14.060997,11.060997,9.060997,5.060997,8.060997,15.060997,26.060997,31.060997,30.060997,29.060997,29.060997,31.060997,33.060997,34.060997,31.060997,29.060997,30.060997,32.060997,33.060997,33.060997,32.060997,32.060997,32.060997,33.060997,36.060997,36.060997,34.060997,29.060997,32.060997,37.060997,38.060997,37.060997,33.060997,34.060997,35.060997,35.060997,35.060997,37.060997,39.060997,41.060997,42.060997,40.060997,40.060997,44.060997,44.060997,44.060997,44.060997,44.060997,46.060997,47.060997,47.060997,46.060997,44.060997,41.060997,42.060997,43.060997,48.060997,48.060997,47.060997,45.060997,46.060997,47.060997,47.060997,49.060997,51.060997,50.060997,50.060997,50.060997,52.060997,53.060997,49.060997,46.060997,44.060997,44.060997,43.060997,42.060997,42.060997,43.060997,39.060997,38.060997,40.060997,41.060997,41.060997,41.060997,42.060997,43.060997,40.060997,41.060997,45.060997,41.060997,38.060997,39.060997,39.060997,39.060997,40.060997,38.060997,34.060997,34.060997,33.060997,34.060997,33.060997,31.060997,32.060997,32.060997,33.060997,29.060997,25.060997,24.060997,20.060997,18.060997,17.060997,16.060997,15.060997,10.060997,5.060997,4.060997,-0.939003,-5.939003,-7.939003,-10.939003,-14.939003,-28.939003,-41.939003,-52.939003,-60.939003,-66.939,-65.939,-64.939,-64.939,-62.939003,-66.939,-90.939,-98.939,-99.939,-94.939,-92.93
9,-94.939,-95.939,-96.939,-96.939,-97.939,-98.939,-99.939,-99.939,-100.939,-100.939,-98.939,-93.939,-91.939,-89.939,-83.939,-78.939,-75.939,-71.939,-69.939,-69.939,-66.939,-61.939003,-52.939003,-47.939003,-47.939003,-44.939003,-42.939003,-42.939003,-44.939003,-46.939003,-47.939003,-49.939003,-53.939003,-53.939003,-52.939003,-51.939003,-55.939003,-59.939003,-56.939003,-53.939003,-50.939003,-46.939003,-44.939003,-45.939003,-43.939003,-40.939003,-36.939003,-36.939003,-41.939003,-37.939003,-33.939003,-28.939003,-23.939003,-18.939003,-11.939003,-6.939003,-3.939003,-1.939003,4.060997,17.060997,0.06099701,-28.939003,-49.939003,-56.939003,-48.939003,-50.939003,-52.939003,-53.939003,-53.939003,-50.939003,-41.939003,-36.939003,-36.939003,-35.939003,-33.939003,-25.939003,-26.939003,-29.939003,-24.939003,-24.939003,-29.939003,-29.939003,-32.939003,-40.939003,-27.939003,-8.939003,-10.939003,-15.939003,-22.939003,-21.939003,-23.939003,-30.939003,-35.939003,-37.939003,-32.939003,-31.939003,-35.939003,-34.939003,-32.939003,-31.939003,-34.939003,-40.939003,-40.939003,-40.939003,-38.939003,-38.939003,-40.939003,-43.939003,-43.939003,-41.939003,-44.939003,-43.939003,-40.939003,-39.939003,-43.939003,-54.939003,-53.939003,-49.939003,-52.939003,-52.939003,-51.939003,-50.939003,-51.939003,-53.939003,-54.939003,-53.939003,-54.939003,-54.939003,-53.939003,-55.939003,-54.939003,-50.939003,-51.939003,-53.939003,-56.939003,-59.939003,-62.939003,-59.939003,-57.939003,-57.939003,-59.939003,-61.939003,-58.939003,-59.939003,-62.939003,-61.939003,-59.939003,-61.939003,-64.939,-69.939,-66.939,-64.939,-65.939,-43.939003,-23.939003,-12.939003,-22.939003,-42.939003,-63.939003,-74.939,-75.939,-78.939,-81.939,-81.939,-82.939,-83.939,-82.939,-82.939,-84.939,-85.939,-85.939,-84.939,-86.939,-88.939,-88.939,-89.939,-89.939,-91.939,-92.939,-92.939,-94.939,-98.939,-98.939,-98.939,-99.939,-98.939,-97.939,-100.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-100.939,-69.939,-26.939003,9.060997,1
9.060997,1.060997,-54.939003,-96.939,-96.939,-96.939,-97.939,-97.939,-97.939,-96.939,-96.939,-96.939,-95.939,-95.939,-95.939,-94.939,-94.939,-93.939,-94.939,-94.939,-94.939,-94.939,-94.939,-93.939,-93.939,-94.939,-93.939,-92.939,-89.939,-80.939,-70.939,-78.939,-84.939,-90.939,-90.939,-90.939,-86.939,-87.939,-89.939,-88.939,-77.939,-53.939003,-40.939003,-32.939003,-20.939003,-25.939003,-31.939003,-32.939003,-34.939003,-38.939003,-42.939003,-47.939003,-50.939003,-50.939003,-50.939003,-51.939003,-51.939003,-50.939003,-52.939003,-54.939003,-60.939003,-66.939,-75.939,-86.939,-95.939,-101.939,-100.939,-99.939,-98.939,-96.939,-96.939,-95.939,-95.939,-95.939,-97.939,-99.939,-98.939,-98.939,-98.939,-99.939,-100.939,-102.939,-90.939,-67.939,-18.939003,-40.939003,-89.939,-98.939,-103.939,-103.939,-102.939,-102.939,-100.939,-100.939,-100.939,-94.939,-92.939,-94.939,-94.939,-92.939,-88.939,-87.939,-85.939,-83.939,-81.939,-79.939,-77.939,-75.939,-73.939,-70.939,-66.939,-66.939,-65.939,-64.939,-60.939003,-56.939003,-53.939003,-51.939003,-50.939003,-50.939003,-45.939003,-36.939003,-16.939003,-2.939003,-6.939003,-8.939003,-10.939003,-19.939003,-28.939003,-33.939003,-34.939003,-33.939003,-28.939003,-31.939003,-35.939003,-33.939003,-33.939003,-35.939003,-35.939003,-35.939003,-36.939003,-40.939003,-44.939003,-47.939003,-49.939003,-51.939003,-55.939003,-58.939003,-60.939003,-60.939003,-60.939003,-67.939,-70.939,-72.939,-75.939,-77.939,-82.939,-85.939,-87.939,-90.939,-93.939,-94.939,-96.939,-99.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-99.939,-99.939,-99.939,-99.939,-97.939,-94.939,-91.939,-88.939,-86.939,-83.939,-61.939003,3.060997,28.060997,34.060997,36.060997,25.060997,2.060997,-34.939003,-61.939003,-55.939003,-53.939003,-52.939003,-49.939003,-48.939003,-49.939003,-47.939003,-44.939003,-39.939003,-37.939003,-35.939003,-33.939003,-32.939003,-32.939003,-34.939003,-36.939003,-34.939003,-32.939003,-31.939003,-
30.939003,-31.939003,-30.939003,-33.939003,-36.939003,-38.939003,-40.939003,-41.939003,-41.939003,-42.939003,-45.939003,-51.939003,-56.939003,-57.939003,-56.939003,-56.939003,-59.939003,-63.939003,-68.939,-71.939,-74.939,-76.939,-79.939,-83.939,-88.939,-91.939,-92.939,-92.939,-92.939,-94.939,-97.939,-101.939,-103.939,-101.939,-98.939,-61.939003,-31.939003,-27.939003,-8.939003,16.060997,55.060997,67.061,53.060997,53.060997,54.060997,52.060997,49.060997,45.060997,42.060997,43.060997,48.060997,29.060997,10.060997,-0.939003,7.060997,20.060997,18.060997,23.060997,35.060997,25.060997,16.060997,13.060997,11.060997,9.060997,7.060997,5.060997,5.060997,3.060997,2.060997,-2.939003,-5.939003,-7.939003,-8.939003,-8.939003,-7.939003,-15.939003,-20.939003,-16.939003,-17.939003,-18.939003,-19.939003,-21.939003,-25.939003,-20.939003,-8.939003,17.060997,25.060997,26.060997,22.060997,7.060997,-20.939003,-44.939003,-56.939003,-36.939003,-36.939003,-43.939003,-44.939003,-46.939003,-49.939003,-51.939003,-53.939003,-56.939003,-58.939003,-60.939003,-62.939003,-65.939,-68.939,-71.939,-75.939,-77.939,-79.939,-82.939,-87.939,-91.939,-94.939,-96.939,-97.939,-98.939,-100.939,-102.939,-102.939,-102.939,-100.939,-102.939,-103.939,-101.939,-102.939,-99.939,-61.939003,-20.939003,22.060997,31.060997,33.060997,25.060997,16.060997,1.060997,-56.939003,-86.939,-84.939,-81.939,-77.939,-75.939,-69.939,-58.939003,-4.939003,14.060997,-3.939003,-4.939003,-14.939003,-52.939003,-61.939003,-57.939003,-55.939003,-50.939003,-45.939003,-45.939003,-44.939003,-42.939003,-40.939003,-39.939003,-38.939003,-38.939003,-37.939003,-36.939003,-35.939003,-34.939003,-32.939003,-30.939003,-29.939003,-30.939003,-34.939003,-33.939003,-32.939003,-30.939003,-32.939003,-35.939003,-36.939003,-33.939003,-28.939003,-34.939003,-40.939003,-43.939003,-41.939003,-39.939003,-46.939003,-51.939003,-52.939003,-55.939003,-58.939003,-59.939003,-59.939003,-59.939003,-65.939,-68.939,-68.939,-76.939,-68.939,-21.939003,21.060997,59.060997,59.060997
,32.060997,-20.939003,-69.939,-102.939,-99.939,-100.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-100.939,-100.939,-100.939,-98.939,-96.939,-95.939,-94.939,-92.939,-89.939,-88.939,-87.939,-84.939,-81.939,-78.939,-76.939,-75.939,-75.939,-73.939,-70.939,-68.939,-67.939,-68.939,-65.939,-62.939003,-60.939003,-51.939003,-40.939003,-30.939003,-26.939003,-32.939003,-32.939003,-29.939003,-24.939003,-15.939003,-8.939003,-9.939003,-7.939003,-4.939003,-3.939003,-3.939003,-8.939003,-6.939003,-0.939003,7.060997,12.060997,18.060997,17.060997,18.060997,24.060997,27.060997,28.060997,26.060997,26.060997,29.060997,32.060997,34.060997,35.060997,36.060997,36.060997,39.060997,42.060997,47.060997,46.060997,44.060997,39.060997,40.060997,43.060997,46.060997,46.060997,43.060997,43.060997,43.060997,45.060997,47.060997,49.060997,47.060997,47.060997,47.060997,47.060997,49.060997,51.060997,53.060997,54.060997,56.060997,58.060997,60.060997,59.060997,57.060997,56.060997,54.060997,52.060997,54.060997,53.060997,50.060997,51.060997,54.060997,60.060997,62.060997,61.060997,60.060997,59.060997,58.060997,55.060997,53.060997,51.060997,53.060997,55.060997,55.060997,54.060997,53.060997,53.060997,52.060997,52.060997,50.060997,49.060997,45.060997,44.060997,45.060997,45.060997,44.060997,44.060997,41.060997,37.060997,32.060997,32.060997,38.060997,36.060997,33.060997,34.060997,33.060997,30.060997,27.060997,24.060997,21.060997,19.060997,17.060997,15.060997,12.060997,9.060997,6.060997,3.060997,3.060997,0.06099701,-4.939003,-8.939003,-12.939003,-17.939003,-17.939003,-19.939003,-22.939003,-30.939003,-36.939003,-37.939003,-42.939003,-48.939003,-49.939003,-51.939003,-55.939003,-64.939,-72.939,-78.939,-83.939,-85.939,-75.939,-62.939003,-46.939003,-41.939003,-48.939003,-76.939,-84.939,-83.939,-82.939,-81.939,-78.939,-77.939,-76.939,-75.939,-75.939,-76.939,-74.939,-73.939,-73.939,-73.939,-71.939,-64.939,-62.939003,-61.939003,-58.939003,-55.939003,-52.939003,-51.939003,-51.939003,-51.939003,-50.939003,-48.93
9003,-47.939003,-47.939003,-48.939003,-46.939003,-45.939003,-46.939003,-46.939003,-46.939003,-45.939003,-42.939003,-38.939003,-36.939003,-36.939003,-38.939003,-38.939003,-35.939003,-33.939003,-30.939003,-25.939003,-24.939003,-27.939003,-36.939003,-30.939003,-16.939003,-4.939003,-2.939003,-11.939003,-9.939003,-7.939003,-5.939003,-6.939003,-5.939003,0.06099701,4.060997,5.060997,2.060997,2.060997,11.060997,-4.939003,-30.939003,-50.939003,-56.939003,-47.939003,-47.939003,-47.939003,-47.939003,-50.939003,-52.939003,-43.939003,-38.939003,-35.939003,-35.939003,-33.939003,-24.939003,-24.939003,-27.939003,-24.939003,-24.939003,-27.939003,-27.939003,-31.939003,-42.939003,-18.939003,17.060997,22.060997,17.060997,2.060997,5.060997,4.060997,-5.939003,-9.939003,-9.939003,0.06099701,3.060997,-0.939003,-1.939003,-2.939003,-1.939003,-17.939003,-35.939003,-24.939003,-12.939003,0.06099701,-2.939003,-6.939003,-6.939003,-7.939003,-8.939003,-9.939003,-7.939003,-5.939003,-27.939003,-50.939003,-67.939,-70.939,-68.939,-68.939,-67.939,-65.939,-65.939,-67.939,-69.939,-68.939,-65.939,-66.939,-66.939,-67.939,-64.939,-59.939003,-53.939003,-56.939003,-62.939003,-62.939003,-63.939003,-64.939,-64.939,-64.939,-62.939003,-60.939003,-58.939003,-58.939003,-58.939003,-59.939003,-58.939003,-58.939003,-61.939003,-62.939003,-64.939,-62.939003,-60.939003,-59.939003,-44.939003,-31.939003,-28.939003,-34.939003,-45.939003,-57.939003,-63.939003,-63.939003,-65.939,-67.939,-65.939,-66.939,-67.939,-65.939,-64.939,-66.939,-66.939,-65.939,-63.939003,-64.939,-67.939,-67.939,-68.939,-69.939,-69.939,-70.939,-71.939,-73.939,-75.939,-75.939,-75.939,-76.939,-77.939,-78.939,-80.939,-79.939,-78.939,-80.939,-82.939,-82.939,-83.939,-84.939,-85.939,-63.939003,-33.939003,-11.939003,-5.939003,-14.939003,-54.939003,-85.939,-86.939,-87.939,-87.939,-86.939,-86.939,-88.939,-90.939,-91.939,-92.939,-91.939,-90.939,-92.939,-94.939,-93.939,-94.939,-95.939,-96.939,-96.939,-96.939,-97.939,-98.939,-99.939,-100.939,-99.939,-93.939,-83.939,-
74.939,-84.939,-92.939,-99.939,-100.939,-100.939,-96.939,-97.939,-99.939,-99.939,-81.939,-46.939003,-30.939003,-21.939003,-31.939003,-33.939003,-36.939003,-38.939003,-39.939003,-38.939003,-43.939003,-48.939003,-50.939003,-52.939003,-54.939003,-53.939003,-52.939003,-52.939003,-60.939003,-69.939,-78.939,-86.939,-93.939,-97.939,-100.939,-100.939,-98.939,-97.939,-96.939,-93.939,-90.939,-91.939,-93.939,-97.939,-99.939,-100.939,-97.939,-95.939,-94.939,-95.939,-96.939,-98.939,-92.939,-72.939,-19.939003,-34.939003,-76.939,-91.939,-99.939,-97.939,-96.939,-96.939,-95.939,-93.939,-91.939,-83.939,-79.939,-81.939,-79.939,-74.939,-67.939,-64.939,-61.939003,-55.939003,-51.939003,-47.939003,-46.939003,-42.939003,-35.939003,-30.939003,-24.939003,-24.939003,-23.939003,-22.939003,-17.939003,-13.939003,-8.939003,-9.939003,-11.939003,-11.939003,-10.939003,-9.939003,-22.939003,-30.939003,-25.939003,-17.939003,-9.939003,-11.939003,-22.939003,-39.939003,-48.939003,-53.939003,-51.939003,-58.939003,-65.939,-68.939,-73.939,-79.939,-83.939,-88.939,-93.939,-95.939,-96.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-102.939,-102.939,-100.939,-100.939,-100.939,-98.939,-99.939,-100.939,-101.939,-101.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-98.939,-98.939,-98.939,-97.939,-97.939,-97.939,-94.939,-89.939,-88.939,-88.939,-87.939,-83.939,-77.939,-71.939,-66.939,-61.939003,-54.939003,-43.939003,-21.939003,-16.939003,-19.939003,-25.939003,-30.939003,-33.939003,-22.939003,-12.939003,-6.939003,-3.939003,-0.939003,-2.939003,-6.939003,-11.939003,-10.939003,-8.939003,-6.939003,-9.939003,-12.939003,-16.939003,-20.939003,-27.939003,-35.939003,-41.939003,-46.939003,-52.939003,-60.939003,-65.939,-68.939,-69.939,-75.939,-80.939,-85.939,-92.939,-97.939,-100.939,-102.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-101.939,-102.939,-102.939,-100.939,-100.939,-101.939,-102.939,-101.939,-100.939,-99.939,-98.939,-97.939,-
96.939,-96.939,-98.939,-97.939,-96.939,-63.939003,-35.939003,-26.939003,-13.939003,1.060997,37.060997,48.060997,35.060997,35.060997,34.060997,32.060997,28.060997,25.060997,18.060997,13.060997,12.060997,-3.939003,-17.939003,-19.939003,-16.939003,-12.939003,-21.939003,-24.939003,-22.939003,-27.939003,-32.939003,-33.939003,-39.939003,-46.939003,-42.939003,-40.939003,-39.939003,-41.939003,-41.939003,-38.939003,-37.939003,-38.939003,-34.939003,-31.939003,-30.939003,-29.939003,-25.939003,-16.939003,-13.939003,-10.939003,-3.939003,3.060997,6.060997,-5.939003,-7.939003,16.060997,24.060997,25.060997,20.060997,8.060997,-10.939003,-63.939003,-100.939,-93.939,-94.939,-99.939,-101.939,-102.939,-103.939,-102.939,-101.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-101.939,-102.939,-100.939,-100.939,-101.939,-100.939,-99.939,-100.939,-100.939,-99.939,-99.939,-99.939,-98.939,-98.939,-96.939,-91.939,-94.939,-96.939,-92.939,-94.939,-93.939,-59.939003,-22.939003,17.060997,28.060997,32.060997,26.060997,13.060997,-5.939003,-52.939003,-70.939,-59.939003,-51.939003,-42.939003,-37.939003,-35.939003,-35.939003,-37.939003,-41.939003,-50.939003,-53.939003,-49.939003,-25.939003,-13.939003,-4.939003,-5.939003,-3.939003,0.06099701,-3.939003,-6.939003,-4.939003,-5.939003,-8.939003,-13.939003,-18.939003,-20.939003,-23.939003,-27.939003,-32.939003,-34.939003,-34.939003,-37.939003,-45.939003,-57.939003,-60.939003,-63.939003,-65.939,-68.939,-72.939,-80.939,-71.939,-48.939003,-70.939,-91.939,-96.939,-86.939,-72.939,-93.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-100.939,-101.939,-81.939,-19.939003,39.060997,92.061,78.061,44.060997,-11.939003,-63.939003,-98.939,-92.939,-94.939,-98.939,-97.939,-97.939,-97.939,-96.939,-95.939,-91.939,-90.939,-90.939,-85.939,-81.939,-79.939,-77.939,-73.939,-67.939,-65.939,-63.939003,-57.939003,-50.939003,-45.939003,-41.939003,-40.939003,-39.939003,-35.939003,-29.939003,-23.939003,-22.939003,-24.939003,-20.939003,-17.939
003,-13.939003,-15.939003,-19.939003,-23.939003,-30.939003,-42.939003,-48.939003,-51.939003,-45.939003,-39.939003,-35.939003,-34.939003,-32.939003,-30.939003,-25.939003,-23.939003,-27.939003,-25.939003,-22.939003,-19.939003,-13.939003,-1.939003,1.060997,5.060997,10.060997,12.060997,11.060997,13.060997,17.060997,22.060997,25.060997,28.060997,32.060997,34.060997,36.060997,40.060997,45.060997,51.060997,51.060997,49.060997,45.060997,43.060997,42.060997,48.060997,51.060997,51.060997,48.060997,46.060997,51.060997,54.060997,57.060997,51.060997,48.060997,46.060997,50.060997,53.060997,52.060997,54.060997,57.060997,63.060997,66.061,67.061,64.061,61.060997,61.060997,60.060997,60.060997,62.060997,58.060997,48.060997,51.060997,58.060997,72.061,73.061,70.061,68.061,65.061,61.060997,54.060997,47.060997,44.060997,44.060997,45.060997,50.060997,52.060997,51.060997,52.060997,52.060997,52.060997,48.060997,43.060997,42.060997,40.060997,39.060997,38.060997,37.060997,36.060997,30.060997,21.060997,15.060997,13.060997,18.060997,19.060997,17.060997,15.060997,11.060997,6.060997,-0.939003,-5.939003,-7.939003,-10.939003,-14.939003,-20.939003,-25.939003,-30.939003,-38.939003,-44.939003,-45.939003,-48.939003,-52.939003,-58.939003,-63.939003,-68.939,-68.939,-71.939,-77.939,-85.939,-91.939,-92.939,-94.939,-97.939,-98.939,-98.939,-98.939,-98.939,-98.939,-96.939,-95.939,-92.939,-73.939,-49.939003,-18.939003,-14.939003,-22.939003,-50.939003,-58.939003,-56.939003,-55.939003,-53.939003,-47.939003,-45.939003,-43.939003,-40.939003,-41.939003,-43.939003,-39.939003,-37.939003,-38.939003,-38.939003,-38.939003,-33.939003,-33.939003,-33.939003,-36.939003,-37.939003,-35.939003,-38.939003,-42.939003,-42.939003,-45.939003,-47.939003,-48.939003,-48.939003,-47.939003,-47.939003,-48.939003,-49.939003,-47.939003,-45.939003,-48.939003,-33.939003,-1.939003,2.060997,-2.939003,-17.939003,-10.939003,3.060997,2.060997,4.060997,10.060997,4.060997,-7.939003,-27.939003,-16.939003,8.060997,24.060997,26.060997,12.060997,10.0609
97,8.060997,7.060997,-0.939003,-8.939003,-6.939003,-5.939003,-6.939003,-16.939003,-21.939003,-20.939003,-26.939003,-36.939003,-47.939003,-48.939003,-40.939003,-37.939003,-36.939003,-34.939003,-41.939003,-49.939003,-44.939003,-39.939003,-34.939003,-35.939003,-31.939003,-23.939003,-22.939003,-25.939003,-23.939003,-23.939003,-25.939003,-24.939003,-28.939003,-43.939003,-12.939003,35.060997,48.060997,43.060997,21.060997,26.060997,28.060997,15.060997,14.060997,19.060997,33.060997,39.060997,36.060997,33.060997,32.060997,33.060997,4.060997,-30.939003,-4.939003,22.060997,50.060997,43.060997,38.060997,43.060997,42.060997,39.060997,42.060997,46.060997,50.060997,-7.939003,-61.939003,-91.939,-100.939,-100.939,-95.939,-94.939,-93.939,-95.939,-98.939,-100.939,-97.939,-93.939,-93.939,-95.939,-98.939,-89.939,-76.939,-61.939003,-71.939,-89.939,-84.939,-82.939,-81.939,-85.939,-87.939,-82.939,-75.939,-68.939,-71.939,-72.939,-69.939,-70.939,-71.939,-75.939,-73.939,-70.939,-71.939,-69.939,-63.939003,-40.939003,-22.939003,-24.939003,-33.939003,-45.939003,-53.939003,-59.939003,-60.939003,-59.939003,-58.939003,-53.939003,-54.939003,-57.939003,-52.939003,-50.939003,-51.939003,-50.939003,-48.939003,-45.939003,-46.939003,-49.939003,-48.939003,-48.939003,-48.939003,-47.939003,-46.939003,-47.939003,-48.939003,-50.939003,-50.939003,-49.939003,-48.939003,-52.939003,-55.939003,-54.939003,-51.939003,-48.939003,-52.939003,-56.939003,-58.939003,-60.939003,-61.939003,-63.939003,-56.939003,-46.939003,-46.939003,-45.939003,-43.939003,-54.939003,-62.939003,-64.939,-66.939,-67.939,-64.939,-64.939,-69.939,-73.939,-74.939,-76.939,-75.939,-73.939,-78.939,-81.939,-80.939,-81.939,-83.939,-84.939,-84.939,-85.939,-88.939,-90.939,-93.939,-94.939,-92.939,-85.939,-77.939,-73.939,-83.939,-90.939,-95.939,-97.939,-98.939,-96.939,-97.939,-99.939,-100.939,-79.939,-39.939003,-25.939003,-17.939003,-33.939003,-35.939003,-38.939003,-41.939003,-42.939003,-43.939003,-46.939003,-50.939003,-57.939003,-60.939003,-60.939003,-62.93
9003,-66.939,-73.939,-83.939,-92.939,-93.939,-96.939,-99.939,-98.939,-99.939,-101.939,-101.939,-100.939,-99.939,-96.939,-94.939,-94.939,-94.939,-97.939,-97.939,-95.939,-89.939,-84.939,-80.939,-81.939,-79.939,-74.939,-65.939,-51.939003,-23.939003,-28.939003,-45.939003,-54.939003,-55.939003,-48.939003,-45.939003,-42.939003,-38.939003,-37.939003,-36.939003,-28.939003,-24.939003,-23.939003,-27.939003,-29.939003,-27.939003,-24.939003,-22.939003,-24.939003,-26.939003,-28.939003,-32.939003,-35.939003,-36.939003,-37.939003,-38.939003,-43.939003,-47.939003,-53.939003,-51.939003,-50.939003,-53.939003,-56.939003,-59.939003,-63.939003,-64.939,-63.939003,-6.939003,38.060997,39.060997,49.060997,57.060997,25.060997,-19.939003,-76.939,-85.939,-86.939,-86.939,-88.939,-90.939,-91.939,-93.939,-95.939,-96.939,-98.939,-100.939,-100.939,-101.939,-102.939,-103.939,-103.939,-102.939,-102.939,-100.939,-99.939,-98.939,-99.939,-101.939,-101.939,-99.939,-98.939,-97.939,-95.939,-93.939,-90.939,-88.939,-87.939,-85.939,-81.939,-73.939,-68.939,-65.939,-63.939003,-60.939003,-56.939003,-51.939003,-46.939003,-43.939003,-41.939003,-39.939003,-38.939003,-35.939003,-33.939003,-29.939003,-27.939003,-24.939003,-24.939003,-24.939003,-24.939003,-24.939003,-26.939003,-28.939003,-28.939003,-20.939003,-9.939003,-0.939003,1.060997,-2.939003,-9.939003,-29.939003,-45.939003,-49.939003,-52.939003,-55.939003,-57.939003,-60.939003,-65.939,-66.939,-67.939,-69.939,-71.939,-73.939,-74.939,-75.939,-78.939,-80.939,-82.939,-84.939,-86.939,-89.939,-90.939,-91.939,-92.939,-94.939,-95.939,-97.939,-99.939,-101.939,-102.939,-103.939,-102.939,-103.939,-102.939,-101.939,-100.939,-101.939,-101.939,-100.939,-99.939,-100.939,-99.939,-93.939,-91.939,-91.939,-87.939,-83.939,-79.939,-74.939,-69.939,-67.939,-63.939003,-59.939003,-55.939003,-52.939003,-50.939003,-38.939003,-27.939003,-21.939003,-18.939003,-18.939003,-16.939003,-18.939003,-23.939003,-25.939003,-27.939003,-28.939003,-27.939003,-25.939003,-26.939003,-27.939003,-25.939003,-
24.939003,-23.939003,-19.939003,-17.939003,-15.939003,-11.939003,-5.939003,4.060997,3.060997,1.060997,1.060997,3.060997,7.060997,12.060997,16.060997,17.060997,18.060997,21.060997,26.060997,28.060997,28.060997,29.060997,31.060997,33.060997,14.060997,5.060997,22.060997,34.060997,44.060997,39.060997,41.060997,49.060997,17.060997,-2.939003,17.060997,24.060997,25.060997,21.060997,9.060997,-9.939003,-62.939003,-102.939,-100.939,-100.939,-101.939,-102.939,-103.939,-102.939,-103.939,-102.939,-102.939,-100.939,-95.939,-96.939,-95.939,-93.939,-90.939,-88.939,-86.939,-82.939,-78.939,-75.939,-72.939,-69.939,-66.939,-63.939003,-60.939003,-56.939003,-51.939003,-47.939003,-43.939003,-40.939003,-37.939003,-33.939003,-31.939003,-33.939003,-34.939003,-27.939003,-11.939003,13.060997,26.060997,33.060997,27.060997,16.060997,-0.939003,-49.939003,-59.939003,-32.939003,-26.939003,-24.939003,-28.939003,-30.939003,-30.939003,-21.939003,-16.939003,-16.939003,-15.939003,-20.939003,-40.939003,-51.939003,-58.939003,-58.939003,-58.939003,-60.939003,-64.939,-68.939,-69.939,-70.939,-71.939,-73.939,-75.939,-75.939,-76.939,-78.939,-79.939,-80.939,-80.939,-81.939,-84.939,-88.939,-89.939,-90.939,-90.939,-91.939,-93.939,-95.939,-84.939,-58.939003,-78.939,-98.939,-100.939,-88.939,-72.939,-93.939,-102.939,-102.939,-100.939,-100.939,-99.939,-95.939,-90.939,-89.939,-87.939,-84.939,-80.939,-66.939,-30.939003,3.060997,33.060997,19.060997,1.060997,-22.939003,-44.939003,-56.939003,-44.939003,-43.939003,-45.939003,-39.939003,-36.939003,-35.939003,-33.939003,-32.939003,-32.939003,-31.939003,-31.939003,-26.939003,-24.939003,-24.939003,-25.939003,-26.939003,-29.939003,-28.939003,-26.939003,-24.939003,-25.939003,-31.939003,-34.939003,-38.939003,-38.939003,-38.939003,-36.939003,-38.939003,-39.939003,-41.939003,-44.939003,-48.939003,-52.939003,-56.939003,-58.939003,-62.939003,-65.939,-65.939,-51.939003,-40.939003,-42.939003,-52.939003,-60.939003,-33.939003,-12.939003,0.06099701,-6.939003,-11.939003,-7.939003,-10.93900
3,-14.939003,-10.939003,-8.939003,-8.939003,-6.939003,-9.939003,-19.939003,-22.939003,-21.939003,-12.939003,-7.939003,-8.939003,-7.939003,-8.939003,-12.939003,-14.939003,-16.939003,-16.939003,-17.939003,-16.939003,-19.939003,-21.939003,-24.939003,-24.939003,-23.939003,-22.939003,-21.939003,-17.939003,-20.939003,-22.939003,-21.939003,-22.939003,-22.939003,-18.939003,-16.939003,-16.939003,-20.939003,-23.939003,-19.939003,-18.939003,-19.939003,-19.939003,-18.939003,-19.939003,-18.939003,-17.939003,-13.939003,-15.939003,-20.939003,-30.939003,-15.939003,25.060997,20.060997,8.060997,-9.939003,-13.939003,-12.939003,-19.939003,-22.939003,-23.939003,-37.939003,-49.939003,-50.939003,-50.939003,-51.939003,-47.939003,-46.939003,-47.939003,-49.939003,-51.939003,-51.939003,-51.939003,-51.939003,-54.939003,-54.939003,-53.939003,-54.939003,-54.939003,-48.939003,-28.939003,-4.939003,-9.939003,-23.939003,-45.939003,-55.939003,-63.939003,-63.939003,-65.939,-66.939,-69.939,-70.939,-71.939,-72.939,-73.939,-75.939,-77.939,-79.939,-81.939,-81.939,-76.939,-77.939,-77.939,-76.939,-72.939,-67.939,-67.939,-66.939,-65.939,-64.939,-62.939003,-59.939003,-57.939003,-58.939003,-55.939003,-52.939003,-48.939003,-48.939003,-49.939003,-49.939003,-46.939003,-44.939003,-48.939003,-48.939003,-43.939003,-41.939003,-41.939003,-43.939003,-45.939003,-46.939003,-46.939003,-48.939003,-52.939003,-54.939003,-54.939003,-50.939003,-56.939003,-66.939,-66.939,-67.939,-67.939,-72.939,-77.939,-78.939,-79.939,-80.939,-81.939,-81.939,-80.939,-81.939,-83.939,-83.939,-84.939,-82.939,-63.939003,-49.939003,-42.939003,-46.939003,-47.939003,-37.939003,-40.939003,-47.939003,-55.939003,-35.939003,11.060997,11.060997,-1.939003,-28.939003,-25.939003,-11.939003,-15.939003,-16.939003,-14.939003,-18.939003,-25.939003,-40.939003,-37.939003,-28.939003,-22.939003,-25.939003,-39.939003,-38.939003,-37.939003,-37.939003,-42.939003,-48.939003,-52.939003,-52.939003,-48.939003,-52.939003,-54.939003,-51.939003,-45.939003,-39.939003,-42.939003
,-42.939003,-39.939003,-38.939003,-38.939003,-38.939003,-41.939003,-45.939003,-43.939003,-42.939003,-39.939003,-33.939003,-26.939003,-22.939003,-23.939003,-26.939003,-24.939003,-22.939003,-24.939003,-23.939003,-27.939003,-40.939003,-33.939003,-18.939003,-22.939003,-27.939003,-31.939003,-27.939003,-25.939003,-31.939003,-32.939003,-31.939003,-23.939003,-20.939003,-23.939003,-23.939003,-22.939003,-15.939003,-28.939003,-45.939003,-32.939003,-19.939003,-6.939003,-7.939003,-6.939003,-4.939003,-4.939003,-5.939003,-1.939003,5.060997,15.060997,-22.939003,-59.939003,-81.939,-85.939,-81.939,-80.939,-80.939,-78.939,-80.939,-83.939,-86.939,-86.939,-86.939,-85.939,-87.939,-91.939,-89.939,-81.939,-61.939003,-72.939,-94.939,-95.939,-95.939,-96.939,-97.939,-97.939,-94.939,-88.939,-81.939,-86.939,-89.939,-88.939,-90.939,-92.939,-94.939,-93.939,-91.939,-92.939,-85.939,-69.939,-17.939003,23.060997,32.060997,-4.939003,-56.939003,-77.939,-88.939,-87.939,-88.939,-88.939,-86.939,-87.939,-88.939,-86.939,-85.939,-86.939,-85.939,-85.939,-84.939,-84.939,-85.939,-83.939,-82.939,-83.939,-84.939,-83.939,-80.939,-81.939,-82.939,-83.939,-81.939,-77.939,-81.939,-84.939,-83.939,-78.939,-74.939,-75.939,-76.939,-79.939,-76.939,-74.939,-76.939,-61.939003,-41.939003,-30.939003,-31.939003,-43.939003,-59.939003,-71.939,-69.939,-70.939,-72.939,-69.939,-68.939,-69.939,-69.939,-69.939,-69.939,-68.939,-66.939,-69.939,-69.939,-65.939,-64.939,-64.939,-64.939,-65.939,-66.939,-65.939,-65.939,-68.939,-67.939,-65.939,-62.939003,-60.939003,-59.939003,-61.939003,-62.939003,-63.939003,-62.939003,-63.939003,-67.939,-68.939,-68.939,-69.939,-64.939,-54.939003,-47.939003,-42.939003,-34.939003,-36.939003,-38.939003,-41.939003,-44.939003,-46.939003,-48.939003,-51.939003,-61.939003,-64.939,-65.939,-69.939,-78.939,-90.939,-97.939,-101.939,-97.939,-93.939,-91.939,-88.939,-87.939,-86.939,-84.939,-82.939,-79.939,-76.939,-73.939,-71.939,-70.939,-72.939,-71.939,-68.939,-62.939003,-56.939003,-52.939003,-51.939003,-49.939003,-44.9390
03,-40.939003,-35.939003,-25.939003,-25.939003,-29.939003,-34.939003,-33.939003,-26.939003,-26.939003,-24.939003,-22.939003,-24.939003,-27.939003,-23.939003,-20.939003,-20.939003,-27.939003,-32.939003,-33.939003,-32.939003,-32.939003,-38.939003,-42.939003,-45.939003,-48.939003,-53.939003,-57.939003,-60.939003,-63.939003,-68.939,-72.939,-79.939,-79.939,-79.939,-82.939,-85.939,-88.939,-92.939,-93.939,-93.939,-1.939003,70.061,73.061,82.061,89.061,52.060997,-9.939003,-94.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-100.939,-100.939,-99.939,-98.939,-97.939,-95.939,-93.939,-91.939,-90.939,-89.939,-87.939,-84.939,-80.939,-77.939,-73.939,-73.939,-73.939,-72.939,-71.939,-69.939,-68.939,-65.939,-62.939003,-57.939003,-54.939003,-55.939003,-56.939003,-54.939003,-46.939003,-39.939003,-36.939003,-36.939003,-34.939003,-32.939003,-28.939003,-26.939003,-26.939003,-25.939003,-24.939003,-25.939003,-25.939003,-27.939003,-25.939003,-24.939003,-22.939003,-25.939003,-29.939003,-32.939003,-33.939003,-35.939003,-41.939003,-39.939003,-17.939003,9.060997,35.060997,35.060997,33.060997,28.060997,-29.939003,-75.939,-81.939,-85.939,-89.939,-90.939,-91.939,-94.939,-96.939,-98.939,-101.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-100.939,-100.939,-99.939,-97.939,-97.939,-97.939,-95.939,-95.939,-94.939,-92.939,-89.939,-86.939,-85.939,-82.939,-77.939,-75.939,-74.939,-74.939,-73.939,-70.939,-70.939,-67.939,-60.939003,-58.939003,-58.939003,-54.939003,-49.939003,-46.939003,-41.939003,-37.939003,-38.939003,-36.939003,-33.939003,-29.939003,-27.939003,-25.939003,-23.939003,-22.939003,-21.939003,-19.939003,-18.939003,-22.939003,-24.939003,-24.939003,-24.939003,-25.939003,-25.939003,-22.939003,-18.939003,-16.939003,-13.939003,-7.939003,-7.939003,-8.939003,-11.939003,-6.939003,2.060997,11.060997,23.060997,39.060997,36.060997,32.060997,31.060997,37.060997,44.060997,46.060997,48.060997,50.060997,53.060997,57.060997,61.060997,62.060997,61.060997,61.060997,63
.060997,68.061,37.060997,19.060997,38.060997,54.060997,68.061,57.060997,56.060997,63.060997,23.060997,-2.939003,15.060997,23.060997,26.060997,22.060997,10.060997,-10.939003,-58.939003,-92.939,-85.939,-82.939,-80.939,-80.939,-79.939,-76.939,-75.939,-73.939,-71.939,-68.939,-64.939,-65.939,-63.939003,-60.939003,-56.939003,-53.939003,-52.939003,-48.939003,-44.939003,-44.939003,-42.939003,-39.939003,-36.939003,-34.939003,-33.939003,-31.939003,-27.939003,-24.939003,-22.939003,-22.939003,-20.939003,-18.939003,-18.939003,-23.939003,-28.939003,-23.939003,-9.939003,14.060997,27.060997,34.060997,29.060997,19.060997,4.060997,-48.939003,-64.939,-45.939003,-41.939003,-42.939003,-48.939003,-50.939003,-45.939003,0.06099701,21.060997,16.060997,23.060997,10.060997,-54.939003,-80.939,-91.939,-90.939,-91.939,-94.939,-97.939,-100.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-100.939,-99.939,-98.939,-97.939,-97.939,-96.939,-94.939,-93.939,-91.939,-89.939,-87.939,-76.939,-55.939003,-67.939,-79.939,-78.939,-68.939,-56.939003,-67.939,-71.939,-69.939,-69.939,-69.939,-67.939,-62.939003,-56.939003,-55.939003,-53.939003,-50.939003,-47.939003,-41.939003,-32.939003,-20.939003,-7.939003,-15.939003,-21.939003,-26.939003,-31.939003,-32.939003,-25.939003,-24.939003,-24.939003,-21.939003,-19.939003,-20.939003,-20.939003,-21.939003,-25.939003,-25.939003,-26.939003,-26.939003,-27.939003,-28.939003,-29.939003,-32.939003,-38.939003,-39.939003,-38.939003,-37.939003,-40.939003,-49.939003,-53.939003,-58.939003,-58.939003,-59.939003,-60.939003,-63.939003,-65.939,-67.939,-71.939,-75.939,-82.939,-84.939,-85.939,-89.939,-87.939,-80.939,-50.939003,-28.939003,-33.939003,-51.939003,-68.939,-29.939003,1.060997,24.060997,11.060997,2.060997,8.060997,4.060997,-2.939003,0.06099701,-1.939003,-7.939003,-8.939003,-15.939003,-31.939003,-36.939003,-34.939003,-27.939003,-24.939003,-26.939003,-29.939003,-32.939003,-39.939003,-42.939003,-44.939003,-46.939003,-48.939003,-50.939003,-52.9390
03,-54.939003,-55.939003,-54.939003,-53.939003,-55.939003,-55.939003,-50.939003,-51.939003,-51.939003,-49.939003,-47.939003,-45.939003,-42.939003,-43.939003,-49.939003,-58.939003,-63.939003,-59.939003,-60.939003,-62.939003,-64.939,-66.939,-68.939,-67.939,-65.939,-63.939003,-65.939,-70.939,-83.939,-61.939003,-5.939003,-15.939003,-36.939003,-63.939003,-66.939,-60.939003,-67.939,-70.939,-71.939,-82.939,-91.939,-92.939,-93.939,-95.939,-94.939,-94.939,-94.939,-99.939,-102.939,-101.939,-98.939,-95.939,-99.939,-99.939,-94.939,-96.939,-93.939,-79.939,-47.939003,-11.939003,-17.939003,-35.939003,-64.939,-77.939,-84.939,-81.939,-80.939,-81.939,-80.939,-79.939,-80.939,-79.939,-78.939,-75.939,-77.939,-79.939,-80.939,-76.939,-67.939,-68.939,-70.939,-68.939,-62.939003,-54.939003,-57.939003,-57.939003,-55.939003,-52.939003,-49.939003,-47.939003,-46.939003,-47.939003,-45.939003,-41.939003,-38.939003,-40.939003,-42.939003,-42.939003,-43.939003,-44.939003,-37.939003,-29.939003,-20.939003,-23.939003,-32.939003,-51.939003,-58.939003,-60.939003,-60.939003,-64.939,-70.939,-72.939,-73.939,-70.939,-76.939,-85.939,-87.939,-89.939,-88.939,-93.939,-98.939,-100.939,-99.939,-99.939,-98.939,-96.939,-93.939,-93.939,-92.939,-91.939,-89.939,-85.939,-62.939003,-48.939003,-42.939003,-46.939003,-46.939003,-36.939003,-39.939003,-46.939003,-52.939003,-38.939003,-6.939003,-9.939003,-20.939003,-39.939003,-40.939003,-35.939003,-37.939003,-37.939003,-36.939003,-36.939003,-39.939003,-45.939003,-48.939003,-48.939003,-45.939003,-47.939003,-55.939003,-54.939003,-52.939003,-49.939003,-49.939003,-51.939003,-55.939003,-54.939003,-47.939003,-46.939003,-45.939003,-37.939003,-35.939003,-36.939003,-44.939003,-46.939003,-44.939003,-45.939003,-45.939003,-46.939003,-46.939003,-45.939003,-43.939003,-42.939003,-42.939003,-33.939003,-26.939003,-24.939003,-25.939003,-27.939003,-26.939003,-25.939003,-26.939003,-25.939003,-29.939003,-41.939003,-37.939003,-27.939003,-37.939003,-42.939003,-42.939003,-41.939003,-40.939003,-42.9390
03,-43.939003,-42.939003,-40.939003,-40.939003,-41.939003,-42.939003,-42.939003,-37.939003,-41.939003,-49.939003,-43.939003,-40.939003,-37.939003,-35.939003,-32.939003,-31.939003,-31.939003,-31.939003,-28.939003,-22.939003,-12.939003,-36.939003,-59.939003,-68.939,-69.939,-65.939,-66.939,-67.939,-65.939,-67.939,-69.939,-70.939,-71.939,-73.939,-71.939,-73.939,-76.939,-77.939,-73.939,-58.939003,-67.939,-84.939,-86.939,-87.939,-88.939,-87.939,-86.939,-83.939,-81.939,-78.939,-80.939,-81.939,-82.939,-85.939,-87.939,-87.939,-88.939,-87.939,-90.939,-82.939,-66.939,-17.939003,21.060997,31.060997,-5.939003,-58.939003,-80.939,-90.939,-88.939,-89.939,-90.939,-91.939,-91.939,-91.939,-92.939,-93.939,-93.939,-96.939,-97.939,-95.939,-95.939,-96.939,-96.939,-96.939,-96.939,-98.939,-99.939,-97.939,-98.939,-99.939,-100.939,-98.939,-94.939,-97.939,-100.939,-99.939,-95.939,-92.939,-91.939,-91.939,-93.939,-89.939,-86.939,-88.939,-74.939,-53.939003,-39.939003,-41.939003,-57.939003,-73.939,-84.939,-81.939,-82.939,-83.939,-81.939,-80.939,-80.939,-78.939,-78.939,-78.939,-77.939,-75.939,-77.939,-75.939,-72.939,-71.939,-70.939,-69.939,-70.939,-71.939,-69.939,-68.939,-70.939,-68.939,-66.939,-63.939003,-59.939003,-58.939003,-62.939003,-63.939003,-63.939003,-61.939003,-61.939003,-64.939,-66.939,-66.939,-64.939,-59.939003,-52.939003,-48.939003,-46.939003,-34.939003,-35.939003,-37.939003,-40.939003,-44.939003,-49.939003,-49.939003,-52.939003,-61.939003,-65.939,-68.939,-76.939,-88.939,-101.939,-102.939,-98.939,-89.939,-78.939,-69.939,-67.939,-62.939003,-55.939003,-49.939003,-42.939003,-37.939003,-31.939003,-27.939003,-22.939003,-21.939003,-23.939003,-21.939003,-19.939003,-16.939003,-12.939003,-9.939003,-6.939003,-6.939003,-9.939003,-17.939003,-24.939003,-24.939003,-26.939003,-28.939003,-33.939003,-34.939003,-31.939003,-37.939003,-43.939003,-47.939003,-55.939003,-62.939003,-67.939,-69.939,-72.939,-78.939,-84.939,-86.939,-89.939,-93.939,-96.939,-97.939,-97.939,-95.939,-96.939,-98.939,-99.939,-99.939,-
99.939,-99.939,-101.939,-100.939,-99.939,-97.939,-97.939,-98.939,-97.939,-98.939,-99.939,-7.939003,65.061,74.061,81.061,86.061,69.061,9.060997,-93.939,-103.939,-103.939,-103.939,-103.939,-103.939,-98.939,-95.939,-95.939,-91.939,-87.939,-85.939,-80.939,-74.939,-67.939,-63.939003,-60.939003,-56.939003,-50.939003,-43.939003,-35.939003,-28.939003,-23.939003,-19.939003,-15.939003,-14.939003,-13.939003,-12.939003,-10.939003,-7.939003,-2.939003,-1.939003,-4.939003,-11.939003,-18.939003,-17.939003,-14.939003,-12.939003,-17.939003,-20.939003,-25.939003,-31.939003,-38.939003,-45.939003,-50.939003,-53.939003,-58.939003,-63.939003,-72.939,-76.939,-80.939,-82.939,-88.939,-94.939,-96.939,-94.939,-89.939,-93.939,-77.939,-13.939003,42.060997,87.061,78.061,76.061,80.061,-21.939003,-100.939,-101.939,-101.939,-101.939,-102.939,-100.939,-98.939,-101.939,-103.939,-103.939,-101.939,-99.939,-102.939,-103.939,-103.939,-102.939,-102.939,-101.939,-97.939,-94.939,-94.939,-91.939,-86.939,-85.939,-84.939,-80.939,-78.939,-75.939,-69.939,-62.939003,-52.939003,-49.939003,-43.939003,-32.939003,-26.939003,-22.939003,-23.939003,-21.939003,-15.939003,-12.939003,-8.939003,-2.939003,-1.939003,-3.939003,-1.939003,-0.939003,0.06099701,-0.939003,-2.939003,-9.939003,-14.939003,-18.939003,-18.939003,-20.939003,-22.939003,-20.939003,-21.939003,-26.939003,-16.939003,0.06099701,20.060997,30.060997,30.060997,36.060997,41.060997,43.060997,44.060997,44.060997,48.060997,55.060997,65.061,47.060997,26.060997,3.060997,15.060997,40.060997,47.060997,61.060997,84.061,71.061,60.060997,56.060997,60.060997,65.061,58.060997,55.060997,58.060997,64.061,68.061,67.061,65.061,60.060997,61.060997,65.061,73.061,40.060997,16.060997,31.060997,47.060997,61.060997,52.060997,47.060997,48.060997,13.060997,-7.939003,10.060997,21.060997,28.060997,25.060997,10.060997,-15.939003,-50.939003,-69.939,-50.939003,-41.939003,-37.939003,-34.939003,-30.939003,-25.939003,-19.939003,-12.939003,-10.939003,-9.939003,-9.939003,-8.939003,-5.939003,-2.9390
03,0.06099701,2.060997,2.060997,0.06099701,-0.939003,-6.939003,-9.939003,-10.939003,-10.939003,-11.939003,-18.939003,-22.939003,-24.939003,-29.939003,-33.939003,-37.939003,-45.939003,-50.939003,-53.939003,-64.939,-75.939,-48.939003,-16.939003,20.060997,32.060997,36.060997,30.060997,22.060997,9.060997,-49.939003,-84.939,-96.939,-96.939,-95.939,-97.939,-95.939,-80.939,28.060997,70.061,47.060997,62.060997,43.060997,-67.939,-100.939,-102.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-99.939,-95.939,-92.939,-87.939,-85.939,-85.939,-81.939,-77.939,-72.939,-66.939,-60.939003,-55.939003,-49.939003,-40.939003,-38.939003,-34.939003,-30.939003,-27.939003,-23.939003,-14.939003,-8.939003,-6.939003,-10.939003,-11.939003,-6.939003,-4.939003,-1.939003,-0.939003,-0.939003,0.06099701,-1.939003,-7.939003,-25.939003,-31.939003,-31.939003,-27.939003,-24.939003,-21.939003,-23.939003,-27.939003,-35.939003,-36.939003,-35.939003,-43.939003,-48.939003,-52.939003,-56.939003,-62.939003,-70.939,-72.939,-75.939,-85.939,-90.939,-89.939,-88.939,-89.939,-94.939,-96.939,-97.939,-97.939,-97.939,-98.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-102.939,-96.939,-85.939,-45.939003,-16.939003,-18.939003,-38.939003,-58.939003,-23.939003,9.060997,41.060997,28.060997,17.060997,19.060997,18.060997,14.060997,14.060997,9.060997,1.060997,-4.939003,-12.939003,-25.939003,-29.939003,-29.939003,-32.939003,-32.939003,-33.939003,-39.939003,-45.939003,-47.939003,-48.939003,-49.939003,-47.939003,-47.939003,-49.939003,-47.939003,-47.939003,-46.939003,-46.939003,-46.939003,-50.939003,-50.939003,-45.939003,-45.939003,-41.939003,-34.939003,-22.939003,-11.939003,-21.939003,-34.939003,-52.939003,-61.939003,-66.939,-67.939,-70.939,-73.939,-72.939,-75.939,-80.939,-82.939,-84.939,-88.939,-90.939,-89.939,-95.939,-80.939,-44.939003,-58.939003,-75.939,-88.939,-83.939,-71.939,-76.939,-79.9
39,-82.939,-80.939,-79.939,-79.939,-82.939,-86.939,-89.939,-91.939,-91.939,-96.939,-99.939,-97.939,-93.939,-88.939,-93.939,-92.939,-84.939,-85.939,-79.939,-55.939003,-26.939003,1.060997,-9.939003,-22.939003,-39.939003,-45.939003,-47.939003,-38.939003,-35.939003,-36.939003,-34.939003,-32.939003,-33.939003,-32.939003,-28.939003,-20.939003,-24.939003,-32.939003,-34.939003,-30.939003,-19.939003,-23.939003,-30.939003,-34.939003,-34.939003,-31.939003,-40.939003,-44.939003,-46.939003,-48.939003,-52.939003,-56.939003,-61.939003,-66.939,-67.939,-66.939,-67.939,-75.939,-79.939,-75.939,-84.939,-92.939,-38.939003,9.060997,49.060997,37.060997,2.060997,-73.939,-97.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-101.939,-101.939,-102.939,-102.939,-102.939,-101.939,-101.939,-99.939,-94.939,-91.939,-88.939,-81.939,-73.939,-73.939,-71.939,-66.939,-62.939003,-55.939003,-47.939003,-43.939003,-47.939003,-46.939003,-45.939003,-45.939003,-44.939003,-43.939003,-39.939003,-43.939003,-55.939003,-58.939003,-57.939003,-51.939003,-57.939003,-68.939,-64.939,-60.939003,-55.939003,-51.939003,-48.939003,-44.939003,-47.939003,-51.939003,-44.939003,-39.939003,-35.939003,-37.939003,-36.939003,-26.939003,-21.939003,-17.939003,-14.939003,-9.939003,-4.939003,0.06099701,7.060997,19.060997,2.060997,-27.939003,-52.939003,-61.939003,-55.939003,-56.939003,-57.939003,-59.939003,-55.939003,-49.939003,-42.939003,-40.939003,-42.939003,-36.939003,-31.939003,-30.939003,-29.939003,-26.939003,-29.939003,-31.939003,-32.939003,-31.939003,-34.939003,-46.939003,-24.939003,8.060997,4.060997,-2.939003,-11.939003,-14.939003,-15.939003,-16.939003,-16.939003,-15.939003,-19.939003,-20.939003,-17.939003,-22.939003,-27.939003,-32.939003,-36.939003,-40.939003,-38.939003,-39.939003,-40.939003,-40.939003,-39.939003,-38.939003,-37.939003,-38.939003,-37.939003,-35.939003,-33.939003,-49.939003,-60.939003,-53.939003,-51.939003,-50.939003,-53.939003,-55.939003,-55.939003,-55.939003,-54.939003,-51.939003,-52.939003,-54.9
39003,-52.939003,-52.939003,-53.939003,-53.939003,-53.939003,-53.939003,-54.939003,-58.939003,-58.939003,-58.939003,-57.939003,-55.939003,-53.939003,-50.939003,-54.939003,-58.939003,-52.939003,-50.939003,-53.939003,-54.939003,-54.939003,-56.939003,-57.939003,-58.939003,-63.939003,-61.939003,-53.939003,-39.939003,-29.939003,-26.939003,-36.939003,-51.939003,-61.939003,-65.939,-64.939,-64.939,-65.939,-66.939,-67.939,-68.939,-71.939,-72.939,-73.939,-81.939,-86.939,-79.939,-79.939,-82.939,-89.939,-91.939,-86.939,-90.939,-94.939,-98.939,-100.939,-101.939,-102.939,-101.939,-97.939,-100.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-98.939,-96.939,-100.939,-93.939,-83.939,-74.939,-74.939,-85.939,-95.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-101.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-99.939,-100.939,-99.939,-98.939,-99.939,-96.939,-86.939,-76.939,-69.939,-86.939,-94.939,-95.939,-93.939,-90.939,-88.939,-90.939,-92.939,-84.939,-64.939,-33.939003,-29.939003,-28.939003,-41.939003,-42.939003,-43.939003,-46.939003,-50.939003,-55.939003,-59.939003,-63.939003,-66.939,-69.939,-74.939,-83.939,-86.939,-79.939,-67.939,-55.939003,-44.939003,-36.939003,-30.939003,-32.939003,-33.939003,-34.939003,-32.939003,-30.939003,-30.939003,-34.939003,-38.939003,-37.939003,-38.939003,-40.939003,-43.939003,-45.939003,-44.939003,-46.939003,-50.939003,-55.939003,-57.939003,-56.939003,-62.939003,-60.939003,-33.939003,-30.939003,-39.939003,-62.939003,-73.939,-71.939,-74.939,-77.939,-80.939,-83.939,-87.939,-89.939,-90.939,-92.939,-93.939,-96.939,-97.939,-99.939,-101.939,-103.939,-103.939,-102.939,-101.939,-101.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-96.939,-95.939,-95.939,-91.939,-85.939,-83.939,-82.939,-83.939,-21.939003,28.060997,36.060997,34.060997,27.060997,18.060997,-10.939003,-59.939003,-59.939003,-53.939003,-48.939003,-46.939003,-46.939003,-43.939003,-42.939003,-41.939003,
-38.939003,-34.939003,-33.939003,-32.939003,-31.939003,-30.939003,-30.939003,-32.939003,-32.939003,-31.939003,-30.939003,-30.939003,-31.939003,-32.939003,-34.939003,-35.939003,-37.939003,-39.939003,-41.939003,-45.939003,-49.939003,-52.939003,-53.939003,-54.939003,-59.939003,-63.939003,-64.939,-63.939003,-62.939003,-64.939,-66.939,-68.939,-71.939,-75.939,-78.939,-81.939,-82.939,-85.939,-87.939,-91.939,-93.939,-96.939,-96.939,-98.939,-100.939,-99.939,-98.939,-97.939,-100.939,-85.939,-30.939003,34.060997,93.061,79.061,75.061,81.061,-15.939003,-93.939,-98.939,-96.939,-91.939,-87.939,-84.939,-81.939,-80.939,-78.939,-75.939,-69.939,-62.939003,-63.939003,-62.939003,-60.939003,-53.939003,-47.939003,-45.939003,-44.939003,-42.939003,-39.939003,-36.939003,-33.939003,-36.939003,-36.939003,-30.939003,-29.939003,-30.939003,-29.939003,-27.939003,-26.939003,-26.939003,-27.939003,-26.939003,-28.939003,-32.939003,-35.939003,-37.939003,-37.939003,-38.939003,-39.939003,-40.939003,-43.939003,-46.939003,-49.939003,-51.939003,-51.939003,-52.939003,-54.939003,-57.939003,-59.939003,-61.939003,-64.939,-66.939,-68.939,-41.939003,-22.939003,-27.939003,-15.939003,5.060997,40.060997,56.060997,52.060997,55.060997,58.060997,57.060997,58.060997,57.060997,58.060997,64.061,74.061,56.060997,33.060997,4.060997,15.060997,41.060997,44.060997,57.060997,79.061,67.061,53.060997,47.060997,45.060997,46.060997,39.060997,35.060997,35.060997,32.060997,30.060997,30.060997,26.060997,20.060997,17.060997,15.060997,17.060997,6.060997,-4.939003,-7.939003,-4.939003,-0.939003,-4.939003,-10.939003,-16.939003,-16.939003,-9.939003,13.060997,25.060997,31.060997,24.060997,8.060997,-16.939003,-41.939003,-52.939003,-26.939003,-25.939003,-32.939003,-33.939003,-34.939003,-34.939003,-35.939003,-36.939003,-37.939003,-37.939003,-39.939003,-45.939003,-47.939003,-48.939003,-50.939003,-53.939003,-54.939003,-55.939003,-56.939003,-59.939003,-61.939003,-61.939003,-61.939003,-62.939003,-65.939,-67.939,-68.939,-71.939,-73.939,-75.939,-78.9
39,-81.939,-82.939,-88.939,-91.939,-61.939003,-22.939003,24.060997,37.060997,40.060997,29.060997,23.060997,12.060997,-47.939003,-85.939,-101.939,-100.939,-99.939,-100.939,-93.939,-75.939,14.060997,47.060997,26.060997,31.060997,15.060997,-58.939003,-77.939,-75.939,-73.939,-69.939,-64.939,-62.939003,-60.939003,-60.939003,-57.939003,-52.939003,-50.939003,-48.939003,-48.939003,-47.939003,-47.939003,-44.939003,-42.939003,-39.939003,-36.939003,-36.939003,-36.939003,-34.939003,-32.939003,-32.939003,-30.939003,-26.939003,-26.939003,-27.939003,-28.939003,-29.939003,-28.939003,-27.939003,-29.939003,-30.939003,-34.939003,-35.939003,-34.939003,-34.939003,-35.939003,-39.939003,-44.939003,-47.939003,-46.939003,-47.939003,-48.939003,-53.939003,-51.939003,-28.939003,2.060997,35.060997,39.060997,23.060997,-11.939003,-44.939003,-69.939,-74.939,-74.939,-74.939,-77.939,-80.939,-82.939,-84.939,-87.939,-90.939,-92.939,-93.939,-98.939,-100.939,-99.939,-99.939,-99.939,-101.939,-101.939,-103.939,-101.939,-98.939,-95.939,-95.939,-95.939,-93.939,-87.939,-80.939,-80.939,-79.939,-75.939,-74.939,-72.939,-70.939,-65.939,-59.939003,-61.939003,-58.939003,-49.939003,-35.939003,-25.939003,-28.939003,-34.939003,-38.939003,-31.939003,-25.939003,-18.939003,-24.939003,-28.939003,-27.939003,-27.939003,-28.939003,-28.939003,-28.939003,-27.939003,-27.939003,-27.939003,-28.939003,-30.939003,-30.939003,-26.939003,-22.939003,-16.939003,-20.939003,-23.939003,-23.939003,-21.939003,-18.939003,-15.939003,-13.939003,-11.939003,-9.939003,-6.939003,-3.939003,-1.939003,0.06099701,-3.939003,-2.939003,3.060997,7.060997,11.060997,15.060997,5.060997,-13.939003,-40.939003,-61.939003,-76.939,-79.939,-78.939,-76.939,-75.939,-75.939,-79.939,-80.939,-79.939,-79.939,-81.939,-84.939,-89.939,-94.939,-89.939,-80.939,-69.939,-75.939,-82.939,-86.939,-85.939,-82.939,-82.939,-81.939,-80.939,-79.939,-78.939,-77.939,-80.939,-83.939,-84.939,-83.939,-80.939,-81.939,-80.939,-78.939,-69.939,-59.939003,-50.939003,-46.939003,-45.939003,-43.93
9003,-41.939003,-41.939003,-37.939003,-32.939003,-30.939003,-29.939003,-29.939003,-39.939003,-46.939003,-45.939003,-45.939003,-47.939003,-51.939003,-52.939003,-52.939003,-55.939003,-57.939003,-56.939003,-60.939003,-66.939,-68.939,-67.939,-64.939,-66.939,-69.939,-71.939,-71.939,-71.939,-75.939,-77.939,-78.939,-79.939,-81.939,-83.939,-85.939,-88.939,-88.939,-88.939,-89.939,-91.939,-93.939,-88.939,-91.939,-93.939,-40.939003,2.060997,34.060997,19.060997,-10.939003,-65.939,-81.939,-80.939,-77.939,-75.939,-74.939,-71.939,-69.939,-68.939,-65.939,-62.939003,-63.939003,-60.939003,-53.939003,-56.939003,-58.939003,-58.939003,-55.939003,-53.939003,-52.939003,-50.939003,-48.939003,-46.939003,-44.939003,-44.939003,-47.939003,-50.939003,-47.939003,-46.939003,-49.939003,-50.939003,-50.939003,-48.939003,-49.939003,-52.939003,-48.939003,-40.939003,-30.939003,-33.939003,-36.939003,-40.939003,-38.939003,-32.939003,-26.939003,-23.939003,-21.939003,-21.939003,-23.939003,-28.939003,-18.939003,-3.939003,1.060997,0.06099701,-3.939003,-5.939003,-6.939003,-7.939003,-8.939003,-8.939003,-0.939003,1.060997,-2.939003,-3.939003,-2.939003,4.060997,-10.939003,-33.939003,-49.939003,-53.939003,-48.939003,-46.939003,-44.939003,-46.939003,-47.939003,-47.939003,-44.939003,-41.939003,-39.939003,-33.939003,-29.939003,-29.939003,-27.939003,-24.939003,-28.939003,-31.939003,-31.939003,-32.939003,-34.939003,-42.939003,-10.939003,35.060997,31.060997,23.060997,12.060997,11.060997,9.060997,3.060997,6.060997,12.060997,15.060997,19.060997,23.060997,17.060997,11.060997,8.060997,-12.939003,-36.939003,-12.939003,4.060997,13.060997,10.060997,8.060997,11.060997,11.060997,8.060997,8.060997,7.060997,4.060997,-13.939003,-26.939003,-19.939003,-20.939003,-24.939003,-22.939003,-23.939003,-26.939003,-29.939003,-31.939003,-30.939003,-33.939003,-36.939003,-37.939003,-38.939003,-39.939003,-38.939003,-39.939003,-44.939003,-44.939003,-43.939003,-44.939003,-45.939003,-47.939003,-47.939003,-47.939003,-45.939003,-48.939003,-50.939003,
-50.939003,-51.939003,-54.939003,-56.939003,-57.939003,-57.939003,-57.939003,-56.939003,-60.939003,-58.939003,-51.939003,-36.939003,-26.939003,-25.939003,-37.939003,-53.939003,-58.939003,-60.939003,-57.939003,-56.939003,-56.939003,-59.939003,-59.939003,-58.939003,-61.939003,-61.939003,-60.939003,-68.939,-72.939,-66.939,-66.939,-67.939,-64.939,-65.939,-70.939,-81.939,-90.939,-95.939,-97.939,-98.939,-96.939,-95.939,-94.939,-95.939,-96.939,-97.939,-97.939,-98.939,-97.939,-97.939,-97.939,-97.939,-98.939,-98.939,-97.939,-94.939,-91.939,-90.939,-92.939,-94.939,-96.939,-96.939,-97.939,-99.939,-100.939,-99.939,-98.939,-97.939,-98.939,-100.939,-100.939,-99.939,-99.939,-98.939,-97.939,-94.939,-91.939,-90.939,-91.939,-91.939,-92.939,-91.939,-89.939,-93.939,-92.939,-76.939,-68.939,-67.939,-85.939,-92.939,-91.939,-93.939,-93.939,-91.939,-94.939,-98.939,-84.939,-60.939003,-26.939003,-21.939003,-19.939003,-47.939003,-48.939003,-50.939003,-53.939003,-55.939003,-60.939003,-67.939,-71.939,-70.939,-72.939,-76.939,-80.939,-76.939,-63.939003,-47.939003,-32.939003,-23.939003,-18.939003,-17.939003,-22.939003,-28.939003,-35.939003,-37.939003,-38.939003,-42.939003,-51.939003,-59.939003,-62.939003,-64.939,-66.939,-71.939,-75.939,-76.939,-80.939,-87.939,-97.939,-100.939,-96.939,-100.939,-89.939,-43.939003,-34.939003,-44.939003,-83.939,-103.939,-102.939,-101.939,-101.939,-103.939,-102.939,-101.939,-100.939,-100.939,-100.939,-98.939,-97.939,-96.939,-96.939,-96.939,-95.939,-94.939,-93.939,-93.939,-91.939,-89.939,-87.939,-84.939,-84.939,-83.939,-80.939,-76.939,-73.939,-74.939,-69.939,-61.939003,-58.939003,-56.939003,-56.939003,-29.939003,-5.939003,1.060997,-4.939003,-16.939003,-20.939003,-25.939003,-32.939003,-27.939003,-20.939003,-14.939003,-12.939003,-12.939003,-13.939003,-14.939003,-15.939003,-13.939003,-12.939003,-12.939003,-15.939003,-18.939003,-20.939003,-24.939003,-29.939003,-32.939003,-35.939003,-38.939003,-43.939003,-50.939003,-53.939003,-58.939003,-64.939,-68.939,-71.939,-75.939,-82.939
,-89.939,-95.939,-99.939,-98.939,-100.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-103.939,-103.939,-102.939,-101.939,-100.939,-100.939,-98.939,-96.939,-94.939,-91.939,-90.939,-90.939,-90.939,-79.939,-40.939003,15.060997,70.061,55.060997,49.060997,53.060997,-16.939003,-71.939,-76.939,-72.939,-65.939,-59.939003,-55.939003,-53.939003,-51.939003,-47.939003,-43.939003,-37.939003,-32.939003,-31.939003,-30.939003,-26.939003,-19.939003,-13.939003,-11.939003,-12.939003,-13.939003,-10.939003,-9.939003,-8.939003,-15.939003,-18.939003,-12.939003,-12.939003,-17.939003,-19.939003,-22.939003,-26.939003,-29.939003,-33.939003,-40.939003,-47.939003,-54.939003,-59.939003,-63.939003,-67.939,-70.939,-75.939,-80.939,-84.939,-88.939,-93.939,-96.939,-96.939,-97.939,-97.939,-97.939,-96.939,-95.939,-99.939,-101.939,-102.939,-57.939003,-24.939003,-28.939003,-15.939003,6.060997,47.060997,65.061,59.060997,57.060997,55.060997,53.060997,52.060997,50.060997,49.060997,51.060997,58.060997,44.060997,25.060997,0.06099701,6.060997,25.060997,26.060997,34.060997,49.060997,40.060997,29.060997,22.060997,17.060997,14.060997,11.060997,8.060997,6.060997,0.06099701,-4.939003,-2.939003,-5.939003,-10.939003,-15.939003,-19.939003,-20.939003,-18.939003,-19.939003,-28.939003,-32.939003,-33.939003,-34.939003,-38.939003,-43.939003,-26.939003,-6.939003,14.060997,27.060997,33.060997,24.060997,9.060997,-12.939003,-39.939003,-52.939003,-28.939003,-31.939003,-45.939003,-49.939003,-53.939003,-56.939003,-62.939003,-68.939,-69.939,-71.939,-74.939,-83.939,-88.939,-91.939,-96.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-98.939,-97.939,-97.939,-95.939,-66.939,-25.939003,24.060997,38.060997,42.060997,28.060997,22.060997,13.060997,-45.939003,-78.939,-86.939,-83.939,-79.939,-80.939,-73.939,-58.939003,-6.939003,11.060997,-4.939003,-8.939003,-17.939003,-45.939003,-49.93900
3,-44.939003,-42.939003,-36.939003,-30.939003,-28.939003,-26.939003,-27.939003,-24.939003,-18.939003,-15.939003,-15.939003,-15.939003,-15.939003,-16.939003,-15.939003,-15.939003,-14.939003,-15.939003,-17.939003,-18.939003,-17.939003,-18.939003,-21.939003,-22.939003,-21.939003,-24.939003,-27.939003,-29.939003,-35.939003,-40.939003,-43.939003,-44.939003,-46.939003,-61.939003,-68.939,-69.939,-66.939,-67.939,-75.939,-83.939,-90.939,-89.939,-90.939,-92.939,-98.939,-87.939,-34.939003,25.060997,83.061,82.061,52.060997,-6.939003,-62.939003,-100.939,-98.939,-97.939,-96.939,-95.939,-94.939,-94.939,-94.939,-93.939,-92.939,-91.939,-90.939,-89.939,-89.939,-88.939,-87.939,-86.939,-83.939,-82.939,-83.939,-80.939,-76.939,-71.939,-72.939,-72.939,-69.939,-60.939003,-51.939003,-52.939003,-50.939003,-45.939003,-42.939003,-40.939003,-39.939003,-33.939003,-25.939003,-29.939003,-29.939003,-25.939003,-29.939003,-34.939003,-36.939003,-33.939003,-28.939003,-35.939003,-43.939003,-50.939003,-49.939003,-48.939003,-45.939003,-45.939003,-44.939003,-42.939003,-38.939003,-32.939003,-25.939003,-21.939003,-23.939003,-23.939003,-22.939003,-12.939003,-1.939003,11.060997,8.060997,6.060997,6.060997,9.060997,13.060997,14.060997,17.060997,22.060997,23.060997,25.060997,27.060997,30.060997,32.060997,27.060997,27.060997,33.060997,38.060997,38.060997,29.060997,2.060997,-33.939003,-64.939,-85.939,-94.939,-93.939,-89.939,-84.939,-80.939,-79.939,-84.939,-85.939,-81.939,-79.939,-80.939,-81.939,-88.939,-97.939,-85.939,-82.939,-86.939,-88.939,-88.939,-86.939,-88.939,-92.939,-89.939,-85.939,-83.939,-83.939,-82.939,-81.939,-83.939,-85.939,-84.939,-81.939,-77.939,-72.939,-68.939,-67.939,-58.939003,-47.939003,-31.939003,-26.939003,-31.939003,-27.939003,-27.939003,-34.939003,-36.939003,-34.939003,-26.939003,-25.939003,-31.939003,-47.939003,-59.939003,-63.939003,-65.939,-66.939,-75.939,-78.939,-76.939,-82.939,-88.939,-91.939,-94.939,-96.939,-97.939,-99.939,-101.939,-100.939,-99.939,-98.939,-98.939,-96.939,-96.939,-95.939,
-93.939,-93.939,-93.939,-92.939,-91.939,-90.939,-89.939,-89.939,-89.939,-87.939,-85.939,-81.939,-80.939,-77.939,-42.939003,-16.939003,1.060997,-10.939003,-27.939003,-53.939003,-59.939003,-57.939003,-55.939003,-53.939003,-51.939003,-47.939003,-44.939003,-44.939003,-41.939003,-39.939003,-41.939003,-37.939003,-26.939003,-32.939003,-37.939003,-38.939003,-38.939003,-38.939003,-38.939003,-40.939003,-44.939003,-40.939003,-38.939003,-42.939003,-49.939003,-56.939003,-49.939003,-46.939003,-47.939003,-50.939003,-50.939003,-45.939003,-49.939003,-54.939003,-53.939003,-37.939003,-7.939003,-9.939003,-16.939003,-31.939003,-23.939003,-6.939003,-1.939003,1.060997,1.060997,-1.939003,-7.939003,-22.939003,-6.939003,19.060997,19.060997,14.060997,4.060997,2.060997,-0.939003,-7.939003,-12.939003,-16.939003,-8.939003,-8.939003,-16.939003,-22.939003,-24.939003,-22.939003,-28.939003,-38.939003,-43.939003,-43.939003,-38.939003,-35.939003,-33.939003,-34.939003,-39.939003,-44.939003,-44.939003,-41.939003,-36.939003,-30.939003,-26.939003,-26.939003,-24.939003,-23.939003,-27.939003,-29.939003,-29.939003,-31.939003,-33.939003,-40.939003,-8.939003,34.060997,29.060997,21.060997,13.060997,12.060997,11.060997,4.060997,7.060997,15.060997,25.060997,32.060997,35.060997,30.060997,27.060997,25.060997,-2.939003,-34.939003,-0.939003,23.060997,38.060997,34.060997,31.060997,36.060997,35.060997,34.060997,34.060997,30.060997,22.060997,9.060997,1.060997,7.060997,4.060997,-0.939003,4.060997,5.060997,2.060997,-2.939003,-6.939003,-8.939003,-11.939003,-12.939003,-16.939003,-17.939003,-18.939003,-18.939003,-22.939003,-34.939003,-30.939003,-21.939003,-24.939003,-26.939003,-28.939003,-29.939003,-31.939003,-29.939003,-31.939003,-32.939003,-36.939003,-39.939003,-42.939003,-44.939003,-45.939003,-45.939003,-44.939003,-43.939003,-46.939003,-44.939003,-37.939003,-21.939003,-9.939003,-10.939003,-25.939003,-44.939003,-45.939003,-45.939003,-44.939003,-41.939003,-40.939003,-44.939003,-45.939003,-44.939003,-45.939003,-46.939003,-44
.939003,-51.939003,-55.939003,-53.939003,-51.939003,-50.939003,-50.939003,-55.939003,-66.939,-79.939,-90.939,-94.939,-96.939,-95.939,-92.939,-90.939,-92.939,-92.939,-92.939,-93.939,-94.939,-95.939,-94.939,-94.939,-94.939,-97.939,-98.939,-96.939,-98.939,-99.939,-102.939,-100.939,-95.939,-93.939,-93.939,-93.939,-94.939,-96.939,-97.939,-96.939,-94.939,-93.939,-94.939,-98.939,-98.939,-97.939,-96.939,-95.939,-93.939,-88.939,-83.939,-77.939,-77.939,-78.939,-79.939,-77.939,-73.939,-80.939,-81.939,-63.939003,-58.939003,-62.939003,-75.939,-80.939,-78.939,-82.939,-84.939,-84.939,-86.939,-90.939,-76.939,-55.939003,-28.939003,-22.939003,-20.939003,-52.939003,-53.939003,-55.939003,-58.939003,-60.939003,-64.939,-68.939,-73.939,-71.939,-71.939,-69.939,-52.939003,-51.939003,-66.939,-69.939,-69.939,-65.939,-67.939,-71.939,-76.939,-81.939,-85.939,-87.939,-88.939,-89.939,-90.939,-91.939,-93.939,-94.939,-94.939,-96.939,-97.939,-97.939,-95.939,-94.939,-98.939,-100.939,-101.939,-101.939,-91.939,-52.939003,-35.939003,-33.939003,-80.939,-102.939,-101.939,-99.939,-99.939,-102.939,-100.939,-96.939,-90.939,-88.939,-89.939,-87.939,-84.939,-76.939,-74.939,-73.939,-69.939,-66.939,-64.939,-62.939003,-60.939003,-52.939003,-45.939003,-39.939003,-35.939003,-30.939003,-25.939003,-22.939003,-18.939003,-15.939003,-14.939003,-16.939003,-14.939003,-14.939003,-15.939003,-19.939003,-21.939003,-18.939003,-13.939003,-7.939003,-9.939003,-17.939003,-32.939003,-35.939003,-38.939003,-43.939003,-47.939003,-50.939003,-53.939003,-58.939003,-65.939,-68.939,-71.939,-74.939,-78.939,-81.939,-82.939,-83.939,-86.939,-87.939,-87.939,-86.939,-88.939,-89.939,-88.939,-90.939,-93.939,-94.939,-96.939,-96.939,-97.939,-98.939,-100.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-101.939,-102.939,-101.939,-99.939,-95.939,-92.939,-88.939,-84.939,-78.939,-72.939,-73.939,-70.939,-62.939003,-55.939003,-46.939003,-30.939003,-18.939003,-9.939003,-15.939003,-22.939
003,-30.939003,-27.939003,-23.939003,-15.939003,-10.939003,-7.939003,-5.939003,-6.939003,-8.939003,-6.939003,-5.939003,-8.939003,-16.939003,-23.939003,-25.939003,-26.939003,-28.939003,-34.939003,-38.939003,-42.939003,-45.939003,-49.939003,-54.939003,-57.939003,-59.939003,-70.939,-80.939,-78.939,-80.939,-83.939,-84.939,-84.939,-85.939,-86.939,-87.939,-89.939,-90.939,-92.939,-93.939,-94.939,-95.939,-95.939,-96.939,-98.939,-99.939,-99.939,-101.939,-101.939,-101.939,-102.939,-100.939,-96.939,-96.939,-97.939,-96.939,-96.939,-97.939,-58.939003,-28.939003,-29.939003,-18.939003,-1.939003,27.060997,40.060997,34.060997,28.060997,23.060997,18.060997,14.060997,12.060997,7.060997,2.060997,-3.939003,-5.939003,-8.939003,-14.939003,-15.939003,-15.939003,-18.939003,-20.939003,-25.939003,-27.939003,-29.939003,-29.939003,-31.939003,-33.939003,-30.939003,-28.939003,-26.939003,-24.939003,-22.939003,-17.939003,-13.939003,-9.939003,-8.939003,-5.939003,-2.939003,-10.939003,-14.939003,-2.939003,5.060997,12.060997,10.060997,15.060997,28.060997,12.060997,1.060997,10.060997,22.060997,32.060997,24.060997,14.060997,0.06099701,-52.939003,-91.939,-86.939,-86.939,-90.939,-91.939,-91.939,-92.939,-94.939,-95.939,-94.939,-95.939,-96.939,-98.939,-100.939,-100.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-99.939,-96.939,-92.939,-91.939,-91.939,-87.939,-83.939,-79.939,-79.939,-78.939,-55.939003,-23.939003,18.060997,31.060997,37.060997,30.060997,20.060997,4.060997,-44.939003,-59.939003,-38.939003,-25.939003,-18.939003,-18.939003,-18.939003,-20.939003,-31.939003,-40.939003,-47.939003,-50.939003,-46.939003,-28.939003,-19.939003,-13.939003,-13.939003,-15.939003,-17.939003,-20.939003,-24.939003,-31.939003,-34.939003,-37.939003,-40.939003,-44.939003,-48.939003,-51.939003,-54.939003,-58.939003,-61.939003,-63.939003,-73.939,-79.939,-82.939,-82.939,-83.939,-84.939,-84.939,-84.939,-85.939,-76.939,-59.939003,-72.939,-85.939,-88.939,-81.939,-72.939,-87.939,-9
5.939,-95.939,-95.939,-95.939,-97.939,-98.939,-100.939,-100.939,-99.939,-98.939,-101.939,-89.939,-42.939003,14.060997,67.061,54.060997,27.060997,-13.939003,-60.939003,-92.939,-82.939,-75.939,-72.939,-72.939,-69.939,-62.939003,-62.939003,-62.939003,-57.939003,-52.939003,-46.939003,-44.939003,-41.939003,-36.939003,-35.939003,-31.939003,-24.939003,-20.939003,-18.939003,-15.939003,-15.939003,-14.939003,-15.939003,-16.939003,-14.939003,-14.939003,-13.939003,-14.939003,-15.939003,-14.939003,-12.939003,-13.939003,-22.939003,-24.939003,-25.939003,-32.939003,-36.939003,-37.939003,-37.939003,-36.939003,-33.939003,-38.939003,-44.939003,-27.939003,-9.939003,6.060997,8.060997,11.060997,15.060997,18.060997,20.060997,26.060997,29.060997,29.060997,39.060997,38.060997,7.060997,1.060997,5.060997,20.060997,37.060997,58.060997,53.060997,46.060997,42.060997,40.060997,39.060997,36.060997,36.060997,39.060997,34.060997,28.060997,25.060997,23.060997,22.060997,10.060997,6.060997,8.060997,10.060997,-2.939003,-47.939003,-72.939,-87.939,-93.939,-93.939,-90.939,-91.939,-91.939,-86.939,-86.939,-89.939,-89.939,-90.939,-92.939,-91.939,-90.939,-89.939,-93.939,-98.939,-94.939,-91.939,-90.939,-95.939,-98.939,-94.939,-96.939,-101.939,-97.939,-97.939,-98.939,-98.939,-97.939,-94.939,-96.939,-99.939,-99.939,-96.939,-92.939,-82.939,-75.939,-79.939,-79.939,-78.939,-75.939,-75.939,-80.939,-82.939,-75.939,-46.939003,1.060997,50.060997,43.060997,8.060997,-56.939003,-79.939,-93.939,-94.939,-94.939,-94.939,-96.939,-97.939,-97.939,-98.939,-99.939,-100.939,-99.939,-98.939,-96.939,-94.939,-93.939,-90.939,-87.939,-83.939,-79.939,-75.939,-73.939,-68.939,-61.939003,-61.939003,-60.939003,-54.939003,-49.939003,-45.939003,-41.939003,-40.939003,-41.939003,-37.939003,-34.939003,-36.939003,-35.939003,-34.939003,-45.939003,-51.939003,-51.939003,-45.939003,-39.939003,-36.939003,-37.939003,-40.939003,-48.939003,-52.939003,-52.939003,-48.939003,-46.939003,-52.939003,-59.939003,-67.939,-73.939,-72.939,-65.939,-70.939,-77.939,-80
.939,-82.939,-83.939,-83.939,-85.939,-86.939,-87.939,-87.939,-87.939,-84.939,-81.939,-54.939003,-39.939003,-39.939003,-38.939003,-37.939003,-33.939003,-34.939003,-39.939003,-47.939003,-36.939003,-6.939003,-7.939003,-15.939003,-29.939003,-28.939003,-21.939003,-21.939003,-18.939003,-14.939003,-17.939003,-25.939003,-43.939003,-39.939003,-29.939003,-38.939003,-44.939003,-46.939003,-49.939003,-51.939003,-51.939003,-54.939003,-58.939003,-59.939003,-59.939003,-58.939003,-61.939003,-62.939003,-57.939003,-46.939003,-33.939003,-34.939003,-32.939003,-29.939003,-30.939003,-30.939003,-32.939003,-35.939003,-39.939003,-42.939003,-41.939003,-38.939003,-28.939003,-21.939003,-22.939003,-23.939003,-21.939003,-23.939003,-24.939003,-24.939003,-26.939003,-31.939003,-41.939003,-36.939003,-26.939003,-35.939003,-39.939003,-38.939003,-39.939003,-40.939003,-40.939003,-41.939003,-39.939003,-29.939003,-26.939003,-29.939003,-24.939003,-21.939003,-22.939003,-28.939003,-34.939003,-30.939003,-25.939003,-18.939003,-18.939003,-17.939003,-13.939003,-10.939003,-7.939003,-3.939003,-7.939003,-18.939003,-6.939003,6.060997,14.060997,13.060997,11.060997,18.060997,22.060997,23.060997,22.060997,21.060997,19.060997,20.060997,22.060997,20.060997,20.060997,22.060997,14.060997,2.060997,-19.939003,-7.939003,17.060997,15.060997,15.060997,16.060997,16.060997,17.060997,18.060997,16.060997,13.060997,12.060997,11.060997,10.060997,9.060997,9.060997,8.060997,5.060997,1.060997,-1.939003,0.06099701,7.060997,17.060997,23.060997,21.060997,9.060997,-5.939003,-3.939003,-6.939003,-11.939003,-7.939003,-5.939003,-9.939003,-13.939003,-16.939003,-16.939003,-17.939003,-20.939003,-26.939003,-32.939003,-36.939003,-33.939003,-30.939003,-65.939,-86.939,-94.939,-97.939,-99.939,-99.939,-98.939,-97.939,-90.939,-90.939,-94.939,-94.939,-94.939,-94.939,-95.939,-96.939,-98.939,-98.939,-95.939,-98.939,-100.939,-98.939,-97.939,-97.939,-101.939,-101.939,-96.939,-98.939,-99.939,-98.939,-98.939,-98.939,-99.939,-98.939,-95.939,-95.939,-96.939,-99.93
9,-100.939,-98.939,-97.939,-95.939,-93.939,-91.939,-84.939,-67.939,-62.939003,-62.939003,-62.939003,-59.939003,-55.939003,-61.939003,-62.939003,-52.939003,-50.939003,-51.939003,-53.939003,-54.939003,-56.939003,-54.939003,-53.939003,-55.939003,-55.939003,-53.939003,-54.939003,-50.939003,-44.939003,-44.939003,-43.939003,-57.939003,-58.939003,-60.939003,-63.939003,-64.939,-63.939003,-66.939,-66.939,-55.939003,-33.939003,-8.939003,7.060997,-10.939003,-62.939003,-80.939,-89.939,-85.939,-88.939,-94.939,-97.939,-99.939,-100.939,-102.939,-103.939,-103.939,-101.939,-99.939,-99.939,-99.939,-100.939,-99.939,-97.939,-93.939,-88.939,-83.939,-85.939,-86.939,-86.939,-85.939,-77.939,-50.939003,-35.939003,-29.939003,-56.939003,-67.939,-63.939003,-61.939003,-61.939003,-62.939003,-60.939003,-56.939003,-52.939003,-49.939003,-46.939003,-45.939003,-43.939003,-38.939003,-38.939003,-40.939003,-38.939003,-36.939003,-34.939003,-35.939003,-36.939003,-36.939003,-35.939003,-33.939003,-33.939003,-31.939003,-31.939003,-34.939003,-35.939003,-34.939003,-37.939003,-43.939003,-40.939003,-41.939003,-46.939003,-25.939003,-1.939003,22.060997,30.060997,31.060997,34.060997,7.060997,-51.939003,-63.939003,-70.939,-74.939,-77.939,-79.939,-82.939,-85.939,-90.939,-93.939,-95.939,-97.939,-100.939,-101.939,-101.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-100.939,-100.939,-102.939,-101.939,-99.939,-96.939,-95.939,-93.939,-93.939,-92.939,-91.939,-89.939,-86.939,-82.939,-79.939,-77.939,-73.939,-69.939,-67.939,-64.939,-62.939003,-58.939003,-58.939003,-58.939003,-56.939003,-52.939003,-47.939003,-48.939003,-47.939003,-45.939003,-40.939003,-36.939003,-39.939003,-37.939003,-30.939003,-28.939003,-27.939003,-26.939003,-20.939003,-13.939003,-13.939003,-13.939003,-13.939003,-20.939003,-28.939003,-32.939003,-31.939003,-28.939003,-32.939003,-37.939003,-41.939003,-41.939003,-41.939003,-44.939003,-50.939003,-57.939003,-58.939003,-59.939003,-61.939003,-66.939,-71.939,-74.939,-76.939,-79.939,-83.939,-86.93
9,-87.939,-94.939,-100.939,-100.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-99.939,-99.939,-97.939,-95.939,-96.939,-95.939,-92.939,-89.939,-85.939,-82.939,-78.939,-76.939,-73.939,-70.939,-66.939,-64.939,-62.939003,-59.939003,-56.939003,-54.939003,-37.939003,-24.939003,-26.939003,-19.939003,-9.939003,-5.939003,-4.939003,-8.939003,-9.939003,-11.939003,-12.939003,-13.939003,-14.939003,-13.939003,-14.939003,-17.939003,-16.939003,-14.939003,-12.939003,-13.939003,-14.939003,-8.939003,-2.939003,1.060997,0.06099701,-1.939003,-4.939003,-3.939003,0.06099701,0.06099701,2.060997,5.060997,10.060997,14.060997,17.060997,19.060997,22.060997,27.060997,30.060997,32.060997,16.060997,7.060997,22.060997,36.060997,49.060997,44.060997,48.060997,63.060997,33.060997,8.060997,7.060997,19.060997,33.060997,27.060997,18.060997,6.060997,-54.939003,-101.939,-103.939,-102.939,-101.939,-99.939,-97.939,-97.939,-96.939,-95.939,-93.939,-91.939,-89.939,-88.939,-85.939,-83.939,-79.939,-75.939,-73.939,-71.939,-68.939,-66.939,-64.939,-61.939003,-58.939003,-57.939003,-55.939003,-52.939003,-50.939003,-46.939003,-45.939003,-46.939003,-43.939003,-40.939003,-40.939003,-41.939003,-40.939003,-35.939003,-17.939003,13.060997,28.060997,36.060997,31.060997,21.060997,5.060997,-40.939003,-58.939003,-45.939003,-32.939003,-24.939003,-30.939003,-34.939003,-36.939003,-20.939003,-15.939003,-21.939003,-15.939003,-16.939003,-37.939003,-45.939003,-47.939003,-48.939003,-50.939003,-52.939003,-55.939003,-58.939003,-63.939003,-66.939,-69.939,-71.939,-75.939,-78.939,-80.939,-82.939,-85.939,-87.939,-89.939,-96.939,-100.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-100.939,-89.939,-69.939,-76.939,-85.939,-89.939,-79.939,-67.939,-79.939,-85.939,-85.939,-82.939,-79.939,-74.939,-69.939,-64.939,-67.939,-67.939,-66.939,-62.939003,-53.939003,-33.939003,-11.939003,8.060997,-3.939003,-15.939003,-29.939003,-42.939003,-49.939003,-45.939003,-39.939003,-34.939003,-36.939003,-36.939003,-34.9
39003,-35.939003,-35.939003,-33.939003,-31.939003,-31.939003,-32.939003,-31.939003,-30.939003,-29.939003,-29.939003,-27.939003,-29.939003,-30.939003,-30.939003,-31.939003,-35.939003,-37.939003,-39.939003,-41.939003,-44.939003,-46.939003,-47.939003,-47.939003,-48.939003,-47.939003,-48.939003,-55.939003,-57.939003,-59.939003,-64.939,-66.939,-65.939,-44.939003,-27.939003,-28.939003,-43.939003,-59.939003,-24.939003,8.060997,39.060997,35.060997,32.060997,35.060997,36.060997,35.060997,35.060997,34.060997,33.060997,38.060997,33.060997,2.060997,-5.939003,-3.939003,6.060997,16.060997,27.060997,21.060997,14.060997,9.060997,4.060997,1.060997,-2.939003,-4.939003,-4.939003,-7.939003,-12.939003,-15.939003,-16.939003,-16.939003,-24.939003,-27.939003,-27.939003,-24.939003,-35.939003,-75.939,-92.939,-99.939,-98.939,-96.939,-92.939,-95.939,-97.939,-94.939,-94.939,-95.939,-96.939,-97.939,-98.939,-97.939,-97.939,-96.939,-97.939,-99.939,-95.939,-94.939,-96.939,-97.939,-97.939,-97.939,-100.939,-103.939,-101.939,-101.939,-102.939,-101.939,-101.939,-99.939,-101.939,-103.939,-103.939,-102.939,-100.939,-92.939,-87.939,-89.939,-89.939,-90.939,-92.939,-93.939,-94.939,-95.939,-85.939,-51.939003,-1.939003,46.060997,30.060997,-3.939003,-57.939003,-72.939,-79.939,-75.939,-75.939,-76.939,-71.939,-69.939,-69.939,-68.939,-68.939,-67.939,-64.939,-60.939003,-60.939003,-60.939003,-60.939003,-57.939003,-54.939003,-54.939003,-53.939003,-52.939003,-55.939003,-52.939003,-46.939003,-45.939003,-47.939003,-52.939003,-49.939003,-43.939003,-46.939003,-49.939003,-49.939003,-49.939003,-48.939003,-49.939003,-52.939003,-54.939003,-42.939003,-30.939003,-16.939003,-18.939003,-26.939003,-48.939003,-55.939003,-59.939003,-64.939,-66.939,-65.939,-59.939003,-56.939003,-61.939003,-64.939,-69.939,-71.939,-70.939,-67.939,-67.939,-68.939,-67.939,-68.939,-70.939,-72.939,-70.939,-64.939,-66.939,-68.939,-68.939,-66.939,-63.939003,-48.939003,-40.939003,-39.939003,-39.939003,-38.939003,-38.939003,-39.939003,-41.939003,-45.939003,-4
1.939003,-28.939003,-31.939003,-36.939003,-43.939003,-44.939003,-43.939003,-43.939003,-41.939003,-42.939003,-40.939003,-41.939003,-51.939003,-49.939003,-44.939003,-48.939003,-50.939003,-50.939003,-49.939003,-46.939003,-44.939003,-42.939003,-41.939003,-39.939003,-38.939003,-36.939003,-35.939003,-32.939003,-25.939003,-31.939003,-42.939003,-45.939003,-44.939003,-40.939003,-42.939003,-43.939003,-44.939003,-44.939003,-41.939003,-41.939003,-40.939003,-38.939003,-27.939003,-20.939003,-21.939003,-22.939003,-22.939003,-21.939003,-20.939003,-20.939003,-25.939003,-31.939003,-37.939003,-24.939003,-8.939003,-21.939003,-28.939003,-28.939003,-33.939003,-36.939003,-37.939003,-37.939003,-36.939003,-31.939003,-29.939003,-30.939003,-30.939003,-30.939003,-31.939003,-34.939003,-37.939003,-39.939003,-39.939003,-38.939003,-37.939003,-37.939003,-37.939003,-34.939003,-29.939003,-29.939003,-31.939003,-34.939003,-28.939003,-22.939003,-17.939003,-16.939003,-17.939003,-17.939003,-15.939003,-11.939003,-11.939003,-12.939003,-14.939003,-12.939003,-9.939003,-9.939003,-8.939003,-6.939003,-10.939003,-18.939003,-32.939003,-23.939003,-6.939003,-5.939003,-4.939003,-2.939003,-3.939003,-3.939003,0.06099701,-0.939003,-0.939003,-3.939003,-4.939003,-4.939003,-2.939003,-0.939003,-2.939003,-3.939003,-4.939003,-4.939003,-6.939003,-7.939003,-3.939003,-0.939003,-0.939003,-0.939003,-0.939003,1.060997,-0.939003,-3.939003,-0.939003,1.060997,-1.939003,-3.939003,-4.939003,-2.939003,-2.939003,-4.939003,-8.939003,-9.939003,-5.939003,-12.939003,-27.939003,-73.939,-98.939,-100.939,-102.939,-102.939,-99.939,-99.939,-99.939,-96.939,-96.939,-98.939,-95.939,-93.939,-93.939,-93.939,-92.939,-95.939,-96.939,-95.939,-97.939,-98.939,-98.939,-96.939,-94.939,-98.939,-99.939,-96.939,-97.939,-99.939,-100.939,-99.939,-97.939,-99.939,-100.939,-98.939,-96.939,-96.939,-98.939,-99.939,-98.939,-96.939,-94.939,-93.939,-95.939,-92.939,-78.939,-63.939003,-47.939003,-47.939003,-46.939003,-45.939003,-48.939003,-48.939003,-45.939003,-46.939003,-4
9.939003,-47.939003,-45.939003,-44.939003,-43.939003,-44.939003,-49.939003,-50.939003,-49.939003,-48.939003,-45.939003,-40.939003,-37.939003,-35.939003,-62.939003,-63.939003,-65.939,-67.939,-65.939,-60.939003,-59.939003,-51.939003,-27.939003,17.060997,66.061,72.061,32.060997,-55.939003,-85.939,-100.939,-93.939,-94.939,-99.939,-100.939,-99.939,-97.939,-98.939,-99.939,-99.939,-96.939,-92.939,-90.939,-90.939,-91.939,-88.939,-84.939,-76.939,-68.939,-62.939003,-65.939,-65.939,-60.939003,-59.939003,-55.939003,-42.939003,-33.939003,-27.939003,-26.939003,-22.939003,-16.939003,-16.939003,-15.939003,-13.939003,-12.939003,-11.939003,-11.939003,-8.939003,-0.939003,-0.939003,-1.939003,-3.939003,-7.939003,-12.939003,-14.939003,-14.939003,-13.939003,-17.939003,-23.939003,-35.939003,-42.939003,-48.939003,-52.939003,-57.939003,-62.939003,-71.939,-79.939,-82.939,-89.939,-95.939,-92.939,-94.939,-102.939,-37.939003,29.060997,82.061,89.061,77.061,87.061,37.060997,-74.939,-95.939,-103.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-99.939,-99.939,-98.939,-99.939,-99.939,-98.939,-97.939,-96.939,-97.939,-93.939,-88.939,-83.939,-79.939,-77.939,-75.939,-72.939,-70.939,-65.939,-60.939003,-52.939003,-45.939003,-41.939003,-33.939003,-26.939003,-21.939003,-16.939003,-12.939003,-8.939003,-6.939003,-5.939003,-3.939003,0.06099701,4.060997,-1.939003,-5.939003,-5.939003,-4.939003,-4.939003,-7.939003,-8.939003,-7.939003,-13.939003,-20.939003,-29.939003,-9.939003,18.060997,19.060997,28.060997,45.060997,-5.939003,-54.939003,-80.939,-83.939,-78.939,-89.939,-96.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-98.939,-98.939,-96.939,-95.939,-92.939,-89.939,-88.939,-84.939,-79.939,-80.939,-79.939,-72.939,-65.939,-59.939003,-51.939003,-43.939003,-39.939003,-34.939003,-3
0.939003,-27.939003,-22.939003,-17.939003,-15.939003,-10.939003,-4.939003,-11.939003,-17.939003,-20.939003,-18.939003,-17.939003,-36.939003,-46.939003,-47.939003,-42.939003,-37.939003,-33.939003,-30.939003,-27.939003,-19.939003,-11.939003,-5.939003,-5.939003,-5.939003,-4.939003,-0.939003,5.060997,21.060997,42.060997,67.061,63.060997,56.060997,48.060997,53.060997,62.060997,58.060997,57.060997,60.060997,65.061,69.061,67.061,65.061,64.061,69.061,71.061,69.061,48.060997,33.060997,41.060997,58.060997,76.061,67.061,66.061,75.061,41.060997,13.060997,5.060997,18.060997,35.060997,31.060997,22.060997,8.060997,-50.939003,-96.939,-95.939,-94.939,-92.939,-87.939,-84.939,-84.939,-81.939,-78.939,-75.939,-71.939,-67.939,-64.939,-59.939003,-54.939003,-45.939003,-38.939003,-34.939003,-28.939003,-22.939003,-20.939003,-15.939003,-9.939003,-4.939003,-0.939003,-0.939003,1.060997,1.060997,5.060997,5.060997,1.060997,2.060997,2.060997,-4.939003,-5.939003,-4.939003,-17.939003,-12.939003,11.060997,28.060997,38.060997,33.060997,23.060997,8.060997,-36.939003,-64.939,-75.939,-65.939,-60.939003,-71.939,-77.939,-72.939,3.060997,38.060997,31.060997,46.060997,34.060997,-54.939003,-90.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-97.939,-97.939,-96.939,-91.939,-82.939,-67.939,-64.939,-65.939,-67.939,-59.939003,-48.939003,-54.939003,-57.939003,-58.939003,-53.939003,-45.939003,-35.939003,-24.939003,-14.939003,-21.939003,-24.939003,-23.939003,-14.939003,-10.939003,-21.939003,-37.939003,-55.939003,-58.939003,-55.939003,-45.939003,-21.939003,-4.939003,-8.939003,-6.939003,-1.939003,-5.939003,-10.939003,-15.939003,-18.939003,-19.939003,-20.939003,-25.939003,-33.939003,-37.939003,-41.939003,-45.939003,-47.939003,-50.939003,-58.939003,-66.939,-74.939,-75.939,-79.939,-85.939,-89.939,-91.939,-96.939,-100.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-1
02.939,-101.939,-101.939,-101.939,-99.939,-94.939,-49.939003,-16.939003,-25.939003,-47.939003,-69.939,-26.939003,14.060997,53.060997,41.060997,29.060997,30.060997,28.060997,24.060997,15.060997,9.060997,8.060997,6.060997,-0.939003,-18.939003,-25.939003,-28.939003,-25.939003,-26.939003,-31.939003,-37.939003,-41.939003,-45.939003,-50.939003,-56.939003,-57.939003,-60.939003,-63.939003,-62.939003,-62.939003,-62.939003,-59.939003,-57.939003,-57.939003,-57.939003,-57.939003,-51.939003,-54.939003,-74.939,-85.939,-90.939,-91.939,-91.939,-90.939,-94.939,-97.939,-97.939,-96.939,-94.939,-98.939,-99.939,-99.939,-99.939,-99.939,-100.939,-99.939,-98.939,-92.939,-92.939,-97.939,-92.939,-90.939,-97.939,-100.939,-101.939,-101.939,-100.939,-99.939,-97.939,-97.939,-99.939,-100.939,-100.939,-101.939,-101.939,-100.939,-99.939,-97.939,-97.939,-93.939,-90.939,-97.939,-96.939,-89.939,-85.939,-74.939,-52.939003,-22.939003,3.060997,-14.939003,-32.939003,-48.939003,-51.939003,-48.939003,-40.939003,-40.939003,-43.939003,-33.939003,-28.939003,-27.939003,-26.939003,-26.939003,-24.939003,-20.939003,-16.939003,-19.939003,-22.939003,-25.939003,-22.939003,-20.939003,-25.939003,-29.939003,-34.939003,-42.939003,-44.939003,-40.939003,-39.939003,-44.939003,-63.939003,-63.939003,-56.939003,-69.939,-76.939,-75.939,-80.939,-82.939,-79.939,-87.939,-92.939,-35.939003,10.060997,46.060997,29.060997,-5.939003,-66.939,-84.939,-84.939,-83.939,-81.939,-78.939,-71.939,-65.939,-65.939,-63.939003,-60.939003,-56.939003,-54.939003,-55.939003,-48.939003,-41.939003,-35.939003,-34.939003,-36.939003,-40.939003,-35.939003,-22.939003,-23.939003,-26.939003,-28.939003,-28.939003,-31.939003,-40.939003,-44.939003,-43.939003,-43.939003,-44.939003,-48.939003,-49.939003,-48.939003,-43.939003,-46.939003,-54.939003,-59.939003,-62.939003,-59.939003,-60.939003,-62.939003,-59.939003,-60.939003,-66.939,-59.939003,-52.939003,-51.939003,-45.939003,-39.939003,-36.939003,-33.939003,-32.939003,-25.939003,-19.939003,-16.939003,-10.939003,-3.939
003,5.060997,7.060997,6.060997,11.060997,17.060997,25.060997,-8.939003,-54.939003,-62.939003,-63.939003,-58.939003,-57.939003,-58.939003,-59.939003,-55.939003,-47.939003,-41.939003,-38.939003,-37.939003,-28.939003,-21.939003,-21.939003,-23.939003,-24.939003,-19.939003,-17.939003,-17.939003,-25.939003,-32.939003,-30.939003,0.06099701,39.060997,22.060997,10.060997,5.060997,-2.939003,-9.939003,-14.939003,-12.939003,-7.939003,-8.939003,-7.939003,-4.939003,-12.939003,-19.939003,-21.939003,-30.939003,-40.939003,-34.939003,-32.939003,-35.939003,-36.939003,-37.939003,-41.939003,-40.939003,-37.939003,-42.939003,-40.939003,-33.939003,-50.939003,-63.939003,-62.939003,-61.939003,-61.939003,-70.939,-72.939,-67.939,-66.939,-67.939,-69.939,-68.939,-65.939,-64.939,-62.939003,-62.939003,-58.939003,-55.939003,-54.939003,-55.939003,-56.939003,-51.939003,-48.939003,-47.939003,-49.939003,-50.939003,-45.939003,-42.939003,-41.939003,-45.939003,-48.939003,-47.939003,-41.939003,-38.939003,-39.939003,-36.939003,-33.939003,-30.939003,-34.939003,-42.939003,-43.939003,-42.939003,-40.939003,-28.939003,-13.939003,-12.939003,-11.939003,-10.939003,-8.939003,-7.939003,-8.939003,-7.939003,-5.939003,-0.939003,2.060997,2.060997,1.060997,6.060997,21.060997,2.060997,-34.939003,-79.939,-101.939,-98.939,-101.939,-102.939,-97.939,-99.939,-102.939,-103.939,-102.939,-100.939,-92.939,-87.939,-89.939,-88.939,-85.939,-88.939,-92.939,-94.939,-95.939,-95.939,-97.939,-95.939,-92.939,-95.939,-96.939,-95.939,-95.939,-96.939,-100.939,-98.939,-95.939,-100.939,-102.939,-101.939,-98.939,-96.939,-97.939,-97.939,-96.939,-95.939,-93.939,-92.939,-98.939,-101.939,-96.939,-68.939,-33.939003,-32.939003,-34.939003,-38.939003,-37.939003,-36.939003,-38.939003,-44.939003,-49.939003,-46.939003,-42.939003,-36.939003,-37.939003,-41.939003,-50.939003,-53.939003,-55.939003,-47.939003,-38.939003,-27.939003,-20.939003,-15.939003,-63.939003,-65.939,-67.939,-65.939,-58.939003,-48.939003,-16.939003,16.060997,41.060997,56.060997,63.060997,55.
060997,18.060997,-46.939003,-74.939,-89.939,-83.939,-77.939,-71.939,-73.939,-70.939,-62.939003,-61.939003,-60.939003,-59.939003,-54.939003,-47.939003,-41.939003,-39.939003,-40.939003,-37.939003,-32.939003,-24.939003,-20.939003,-19.939003,-22.939003,-24.939003,-22.939003,-23.939003,-24.939003,-23.939003,-24.939003,-26.939003,-27.939003,-27.939003,-26.939003,-33.939003,-38.939003,-39.939003,-42.939003,-46.939003,-54.939003,-56.939003,-53.939003,-55.939003,-58.939003,-62.939003,-64.939,-67.939,-70.939,-71.939,-73.939,-74.939,-76.939,-80.939,-83.939,-85.939,-86.939,-88.939,-89.939,-92.939,-95.939,-96.939,-98.939,-100.939,-99.939,-100.939,-103.939,-50.939003,11.060997,83.061,88.061,69.061,86.061,44.060997,-57.939003,-87.939,-103.939,-103.939,-101.939,-100.939,-102.939,-101.939,-96.939,-95.939,-93.939,-89.939,-85.939,-82.939,-78.939,-74.939,-70.939,-66.939,-63.939003,-60.939003,-56.939003,-50.939003,-46.939003,-41.939003,-36.939003,-32.939003,-29.939003,-24.939003,-22.939003,-20.939003,-19.939003,-17.939003,-17.939003,-18.939003,-20.939003,-19.939003,-20.939003,-20.939003,-20.939003,-22.939003,-26.939003,-31.939003,-34.939003,-35.939003,-36.939003,-38.939003,-42.939003,-48.939003,-54.939003,-58.939003,-62.939003,-66.939,-69.939,-70.939,-70.939,-70.939,-69.939,-72.939,-72.939,-65.939,-1.939003,75.061,58.060997,60.060997,82.061,11.060997,-55.939003,-87.939,-96.939,-93.939,-98.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-97.939,-94.939,-91.939,-88.939,-85.939,-78.939,-71.939,-70.939,-67.939,-62.939003,-58.939003,-54.939003,-48.939003,-45.939003,-41.939003,-38.939003,-35.939003,-34.939003,-31.939003,-27.939003,-21.939003,-22.939003,-22.939003,-18.939003,-18.939003,-19.939003,-19.939003,-18.939003,-18.939003,-20.939003,-22.939003,-23.939003,-27.939003,-32.939003,-34.939003,-37.939003,-42.939003,-28.939003,-18.939003,-20.939003,-19.939003,-15.939003,5.060997,16.060997,19.060997,22.060997,25.060997,30.0
60997,32.060997,33.060997,34.060997,39.060997,49.060997,49.060997,37.060997,-1.939003,8.060997,36.060997,36.060997,50.060997,80.061,75.061,66.061,56.060997,58.060997,65.061,63.060997,64.061,66.061,66.061,66.061,66.061,66.061,63.060997,61.060997,60.060997,60.060997,43.060997,27.060997,22.060997,33.060997,47.060997,36.060997,30.060997,31.060997,14.060997,2.060997,9.060997,23.060997,36.060997,29.060997,17.060997,0.06099701,-40.939003,-61.939003,-30.939003,-25.939003,-31.939003,-27.939003,-24.939003,-22.939003,-22.939003,-21.939003,-17.939003,-18.939003,-21.939003,-20.939003,-20.939003,-20.939003,-18.939003,-19.939003,-21.939003,-21.939003,-21.939003,-26.939003,-29.939003,-31.939003,-32.939003,-33.939003,-37.939003,-40.939003,-45.939003,-47.939003,-49.939003,-52.939003,-56.939003,-60.939003,-66.939,-69.939,-69.939,-58.939003,-29.939003,20.060997,33.060997,39.060997,37.060997,28.060997,12.060997,-33.939003,-67.939,-92.939,-90.939,-89.939,-92.939,-94.939,-85.939,14.060997,59.060997,48.060997,59.060997,41.060997,-53.939003,-91.939,-103.939,-103.939,-101.939,-97.939,-95.939,-93.939,-90.939,-87.939,-85.939,-83.939,-82.939,-80.939,-75.939,-71.939,-67.939,-62.939003,-58.939003,-55.939003,-51.939003,-47.939003,-41.939003,-37.939003,-37.939003,-36.939003,-34.939003,-29.939003,-28.939003,-30.939003,-26.939003,-22.939003,-19.939003,-20.939003,-23.939003,-17.939003,-15.939003,-17.939003,-19.939003,-20.939003,-18.939003,-18.939003,-19.939003,-28.939003,-33.939003,-32.939003,-39.939003,-41.939003,-27.939003,-8.939003,8.060997,9.060997,1.060997,-16.939003,-45.939003,-65.939,-60.939003,-63.939003,-69.939,-70.939,-72.939,-74.939,-75.939,-75.939,-75.939,-77.939,-80.939,-81.939,-82.939,-84.939,-84.939,-85.939,-87.939,-90.939,-93.939,-94.939,-95.939,-97.939,-96.939,-96.939,-100.939,-102.939,-103.939,-102.939,-99.939,-95.939,-97.939,-97.939,-91.939,-84.939,-77.939,-78.939,-76.939,-69.939,-42.939003,-23.939003,-31.939003,-39.939003,-46.939003,-35.939003,-24.939003,-13.939003,-22.939003,-30.9
39003,-31.939003,-33.939003,-36.939003,-37.939003,-39.939003,-41.939003,-42.939003,-42.939003,-40.939003,-38.939003,-37.939003,-40.939003,-41.939003,-42.939003,-42.939003,-40.939003,-36.939003,-35.939003,-37.939003,-32.939003,-31.939003,-33.939003,-27.939003,-21.939003,-17.939003,-13.939003,-10.939003,-10.939003,-8.939003,-5.939003,3.060997,5.060997,-9.939003,-37.939003,-65.939,-46.939003,-36.939003,-35.939003,-39.939003,-41.939003,-41.939003,-43.939003,-47.939003,-55.939003,-62.939003,-66.939,-70.939,-72.939,-72.939,-74.939,-75.939,-77.939,-67.939,-46.939003,-51.939003,-63.939003,-85.939,-88.939,-83.939,-82.939,-81.939,-80.939,-80.939,-79.939,-76.939,-79.939,-83.939,-80.939,-76.939,-71.939,-69.939,-68.939,-72.939,-73.939,-74.939,-84.939,-75.939,-49.939003,-42.939003,-38.939003,-41.939003,-41.939003,-42.939003,-39.939003,-41.939003,-47.939003,-47.939003,-47.939003,-47.939003,-49.939003,-52.939003,-52.939003,-52.939003,-52.939003,-59.939003,-63.939003,-61.939003,-60.939003,-59.939003,-60.939003,-60.939003,-60.939003,-61.939003,-61.939003,-64.939,-62.939003,-59.939003,-61.939003,-59.939003,-55.939003,-51.939003,-49.939003,-52.939003,-49.939003,-45.939003,-52.939003,-54.939003,-52.939003,-49.939003,-45.939003,-41.939003,-42.939003,-42.939003,-10.939003,13.060997,29.060997,16.060997,-4.939003,-34.939003,-41.939003,-37.939003,-38.939003,-39.939003,-39.939003,-38.939003,-38.939003,-39.939003,-41.939003,-43.939003,-44.939003,-45.939003,-46.939003,-47.939003,-48.939003,-44.939003,-45.939003,-46.939003,-47.939003,-45.939003,-39.939003,-38.939003,-37.939003,-34.939003,-32.939003,-32.939003,-39.939003,-39.939003,-34.939003,-35.939003,-35.939003,-33.939003,-30.939003,-26.939003,-26.939003,-25.939003,-22.939003,-26.939003,-31.939003,-32.939003,-28.939003,-21.939003,-20.939003,-17.939003,-13.939003,-18.939003,-25.939003,-34.939003,-17.939003,6.060997,6.060997,4.060997,1.060997,1.060997,3.060997,3.060997,2.060997,-0.939003,4.060997,1.060997,-6.939003,-10.939003,-11.939003,-8.93900
3,-26.939003,-51.939003,-50.939003,-49.939003,-48.939003,-42.939003,-39.939003,-42.939003,-44.939003,-44.939003,-39.939003,-36.939003,-34.939003,-26.939003,-20.939003,-21.939003,-21.939003,-21.939003,-22.939003,-22.939003,-18.939003,-26.939003,-32.939003,-31.939003,5.060997,49.060997,27.060997,18.060997,24.060997,18.060997,10.060997,2.060997,12.060997,26.060997,24.060997,25.060997,29.060997,26.060997,21.060997,13.060997,-7.939003,-26.939003,6.060997,24.060997,25.060997,24.060997,25.060997,27.060997,22.060997,17.060997,26.060997,25.060997,12.060997,-45.939003,-89.939,-86.939,-85.939,-84.939,-91.939,-92.939,-89.939,-88.939,-88.939,-89.939,-89.939,-88.939,-87.939,-87.939,-89.939,-75.939,-64.939,-64.939,-72.939,-83.939,-80.939,-79.939,-80.939,-82.939,-81.939,-74.939,-74.939,-77.939,-76.939,-76.939,-75.939,-75.939,-75.939,-74.939,-71.939,-70.939,-74.939,-66.939,-44.939003,-32.939003,-24.939003,-22.939003,-38.939003,-59.939003,-58.939003,-56.939003,-52.939003,-53.939003,-53.939003,-52.939003,-52.939003,-52.939003,-49.939003,-49.939003,-51.939003,-50.939003,-46.939003,-35.939003,-50.939003,-75.939,-90.939,-97.939,-97.939,-100.939,-101.939,-95.939,-96.939,-99.939,-102.939,-94.939,-75.939,-64.939,-60.939003,-70.939,-73.939,-74.939,-83.939,-88.939,-89.939,-94.939,-96.939,-98.939,-95.939,-92.939,-95.939,-98.939,-100.939,-98.939,-98.939,-101.939,-100.939,-100.939,-102.939,-102.939,-100.939,-99.939,-99.939,-100.939,-101.939,-101.939,-100.939,-97.939,-94.939,-99.939,-100.939,-93.939,-58.939003,-15.939003,-16.939003,-19.939003,-23.939003,-23.939003,-25.939003,-31.939003,-34.939003,-33.939003,-22.939003,-17.939003,-17.939003,-18.939003,-20.939003,-25.939003,-24.939003,-21.939003,-22.939003,-22.939003,-22.939003,-18.939003,-14.939003,-66.939,-68.939,-69.939,-63.939003,-53.939003,-38.939003,-6.939003,23.060997,37.060997,34.060997,23.060997,12.060997,-8.939003,-39.939003,-51.939003,-58.939003,-53.939003,-47.939003,-41.939003,-43.939003,-42.939003,-38.939003,-36.939003,-35.939003,-35.9
39003,-32.939003,-29.939003,-25.939003,-25.939003,-28.939003,-28.939003,-25.939003,-20.939003,-19.939003,-22.939003,-26.939003,-30.939003,-33.939003,-35.939003,-36.939003,-35.939003,-29.939003,-23.939003,-40.939003,-50.939003,-53.939003,-60.939003,-66.939,-68.939,-71.939,-75.939,-84.939,-87.939,-86.939,-89.939,-92.939,-95.939,-96.939,-97.939,-99.939,-100.939,-102.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-99.939,-97.939,-95.939,-93.939,-93.939,-55.939003,-6.939003,60.060997,63.060997,43.060997,56.060997,27.060997,-41.939003,-64.939,-75.939,-75.939,-71.939,-67.939,-69.939,-67.939,-62.939003,-59.939003,-57.939003,-53.939003,-50.939003,-48.939003,-46.939003,-43.939003,-40.939003,-37.939003,-36.939003,-34.939003,-31.939003,-28.939003,-26.939003,-23.939003,-19.939003,-19.939003,-19.939003,-18.939003,-19.939003,-20.939003,-23.939003,-23.939003,-24.939003,-27.939003,-32.939003,-34.939003,-37.939003,-39.939003,-42.939003,-46.939003,-52.939003,-58.939003,-62.939003,-65.939,-67.939,-69.939,-75.939,-82.939,-89.939,-91.939,-94.939,-99.939,-102.939,-103.939,-102.939,-101.939,-101.939,-102.939,-99.939,-85.939,-2.939003,96.061,73.061,72.061,94.061,19.060997,-52.939003,-88.939,-95.939,-89.939,-89.939,-88.939,-87.939,-85.939,-82.939,-81.939,-79.939,-77.939,-75.939,-73.939,-70.939,-69.939,-67.939,-67.939,-63.939003,-59.939003,-55.939003,-52.939003,-50.939003,-43.939003,-38.939003,-38.939003,-36.939003,-33.939003,-31.939003,-29.939003,-25.939003,-24.939003,-23.939003,-22.939003,-23.939003,-24.939003,-23.939003,-22.939003,-21.939003,-24.939003,-26.939003,-24.939003,-27.939003,-31.939003,-34.939003,-36.939003,-38.939003,-41.939003,-45.939003,-47.939003,-52.939003,-59.939003,-62.939003,-67.939,-74.939,-48.939003,-28.939003,-26.939003,-21.939003,-10.939003,29.060997,51.060997,57.060997,57.060997,58.060997,62.060997,63.060997,65.061,62.060997,64.061,71.061,71.061,54.060997,0.06099701,10.060997,43.060997,34.060997,43.060997,73.061
,65.061,55.060997,44.060997,43.060997,46.060997,42.060997,41.060997,42.060997,37.060997,34.060997,34.060997,34.060997,32.060997,27.060997,23.060997,21.060997,13.060997,5.060997,-1.939003,1.060997,7.060997,0.06099701,-4.939003,-7.939003,-10.939003,-7.939003,12.060997,26.060997,35.060997,27.060997,15.060997,-0.939003,-34.939003,-49.939003,-16.939003,-12.939003,-19.939003,-19.939003,-18.939003,-18.939003,-22.939003,-24.939003,-21.939003,-24.939003,-30.939003,-30.939003,-32.939003,-34.939003,-36.939003,-38.939003,-42.939003,-45.939003,-47.939003,-52.939003,-57.939003,-61.939003,-64.939,-67.939,-70.939,-74.939,-79.939,-82.939,-85.939,-86.939,-91.939,-95.939,-99.939,-102.939,-101.939,-80.939,-38.939003,23.060997,33.060997,37.060997,38.060997,30.060997,16.060997,-28.939003,-65.939,-96.939,-93.939,-88.939,-86.939,-85.939,-76.939,4.060997,38.060997,27.060997,30.060997,16.060997,-43.939003,-66.939,-70.939,-70.939,-67.939,-64.939,-61.939003,-59.939003,-55.939003,-53.939003,-52.939003,-51.939003,-50.939003,-48.939003,-45.939003,-41.939003,-40.939003,-36.939003,-32.939003,-30.939003,-28.939003,-28.939003,-23.939003,-21.939003,-22.939003,-22.939003,-22.939003,-21.939003,-22.939003,-25.939003,-26.939003,-26.939003,-27.939003,-27.939003,-28.939003,-28.939003,-28.939003,-29.939003,-34.939003,-38.939003,-38.939003,-42.939003,-47.939003,-54.939003,-58.939003,-57.939003,-68.939,-65.939,-31.939003,16.060997,63.060997,62.060997,41.060997,0.06099701,-57.939003,-97.939,-91.939,-95.939,-103.939,-102.939,-102.939,-101.939,-99.939,-96.939,-96.939,-95.939,-95.939,-93.939,-92.939,-90.939,-89.939,-88.939,-85.939,-83.939,-81.939,-80.939,-80.939,-80.939,-75.939,-71.939,-75.939,-74.939,-70.939,-69.939,-66.939,-61.939003,-61.939003,-61.939003,-55.939003,-50.939003,-43.939003,-44.939003,-43.939003,-42.939003,-34.939003,-29.939003,-34.939003,-34.939003,-32.939003,-34.939003,-35.939003,-37.939003,-41.939003,-43.939003,-43.939003,-42.939003,-42.939003,-39.939003,-37.939003,-38.939003,-34.939003,-33.9390
03,-34.939003,-32.939003,-29.939003,-26.939003,-20.939003,-12.939003,-9.939003,-6.939003,-3.939003,-2.939003,-1.939003,2.060997,4.060997,3.060997,8.060997,13.060997,16.060997,20.060997,22.060997,16.060997,17.060997,23.060997,28.060997,30.060997,24.060997,-7.939003,-44.939003,-35.939003,-27.939003,-20.939003,-18.939003,-17.939003,-14.939003,-14.939003,-17.939003,-23.939003,-29.939003,-33.939003,-33.939003,-32.939003,-29.939003,-29.939003,-30.939003,-34.939003,-24.939003,-0.939003,-7.939003,-21.939003,-44.939003,-49.939003,-48.939003,-43.939003,-40.939003,-39.939003,-41.939003,-42.939003,-39.939003,-43.939003,-49.939003,-49.939003,-47.939003,-44.939003,-46.939003,-51.939003,-61.939003,-69.939,-76.939,-83.939,-70.939,-39.939003,-37.939003,-37.939003,-34.939003,-31.939003,-28.939003,-21.939003,-26.939003,-43.939003,-46.939003,-47.939003,-47.939003,-47.939003,-49.939003,-53.939003,-54.939003,-53.939003,-59.939003,-61.939003,-58.939003,-57.939003,-58.939003,-55.939003,-53.939003,-53.939003,-55.939003,-55.939003,-56.939003,-53.939003,-49.939003,-50.939003,-49.939003,-45.939003,-43.939003,-40.939003,-37.939003,-35.939003,-35.939003,-39.939003,-39.939003,-37.939003,-33.939003,-30.939003,-28.939003,-27.939003,-26.939003,-15.939003,-9.939003,-7.939003,-13.939003,-19.939003,-24.939003,-24.939003,-23.939003,-22.939003,-22.939003,-23.939003,-23.939003,-24.939003,-26.939003,-28.939003,-29.939003,-32.939003,-32.939003,-30.939003,-34.939003,-37.939003,-36.939003,-37.939003,-37.939003,-37.939003,-36.939003,-34.939003,-34.939003,-34.939003,-31.939003,-30.939003,-31.939003,-33.939003,-31.939003,-25.939003,-28.939003,-28.939003,-25.939003,-23.939003,-21.939003,-24.939003,-19.939003,-5.939003,-10.939003,-17.939003,-21.939003,-15.939003,-5.939003,-7.939003,-5.939003,-1.939003,-11.939003,-23.939003,-33.939003,-19.939003,2.060997,-1.939003,-6.939003,-10.939003,-11.939003,-10.939003,-10.939003,-13.939003,-18.939003,-20.939003,-24.939003,-31.939003,-35.939003,-36.939003,-33.939003,-38.939003,
-47.939003,-42.939003,-40.939003,-40.939003,-35.939003,-31.939003,-35.939003,-40.939003,-43.939003,-39.939003,-36.939003,-33.939003,-27.939003,-22.939003,-21.939003,-20.939003,-19.939003,-21.939003,-22.939003,-21.939003,-27.939003,-33.939003,-33.939003,-10.939003,18.060997,0.06099701,-5.939003,0.06099701,-1.939003,-5.939003,-12.939003,-3.939003,9.060997,9.060997,11.060997,14.060997,14.060997,13.060997,7.060997,-9.939003,-25.939003,5.060997,20.060997,23.060997,23.060997,25.060997,28.060997,24.060997,19.060997,32.060997,30.060997,14.060997,-45.939003,-90.939,-88.939,-88.939,-88.939,-89.939,-89.939,-89.939,-90.939,-91.939,-91.939,-91.939,-92.939,-92.939,-94.939,-98.939,-81.939,-67.939,-67.939,-80.939,-98.939,-94.939,-93.939,-96.939,-97.939,-97.939,-92.939,-93.939,-97.939,-95.939,-94.939,-93.939,-94.939,-96.939,-95.939,-93.939,-93.939,-98.939,-80.939,-37.939003,-8.939003,7.060997,1.060997,-39.939003,-89.939,-88.939,-86.939,-82.939,-84.939,-84.939,-83.939,-83.939,-84.939,-83.939,-84.939,-86.939,-84.939,-82.939,-75.939,-83.939,-93.939,-95.939,-96.939,-98.939,-97.939,-96.939,-92.939,-94.939,-97.939,-101.939,-87.939,-56.939003,-43.939003,-40.939003,-57.939003,-63.939003,-64.939,-75.939,-81.939,-83.939,-89.939,-93.939,-93.939,-92.939,-92.939,-93.939,-96.939,-98.939,-98.939,-98.939,-100.939,-100.939,-101.939,-102.939,-101.939,-99.939,-100.939,-102.939,-102.939,-103.939,-103.939,-102.939,-100.939,-97.939,-100.939,-100.939,-90.939,-60.939003,-25.939003,-27.939003,-28.939003,-30.939003,-30.939003,-32.939003,-38.939003,-39.939003,-35.939003,-24.939003,-21.939003,-24.939003,-24.939003,-24.939003,-25.939003,-24.939003,-21.939003,-26.939003,-29.939003,-34.939003,-31.939003,-29.939003,-71.939,-71.939,-71.939,-62.939003,-48.939003,-30.939003,-29.939003,-31.939003,-39.939003,-47.939003,-54.939003,-54.939003,-47.939003,-33.939003,-18.939003,-8.939003,-4.939003,-5.939003,-9.939003,-10.939003,-15.939003,-24.939003,-25.939003,-25.939003,-26.939003,-31.939003,-38.939003,-42.939003,-48.93900
3,-53.939003,-61.939003,-64.939,-63.939003,-64.939,-70.939,-75.939,-83.939,-92.939,-95.939,-92.939,-79.939,-48.939003,-20.939003,-65.939,-91.939,-98.939,-98.939,-99.939,-100.939,-100.939,-100.939,-101.939,-102.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-100.939,-99.939,-100.939,-101.939,-101.939,-100.939,-101.939,-103.939,-103.939,-102.939,-99.939,-99.939,-98.939,-95.939,-92.939,-87.939,-78.939,-73.939,-72.939,-51.939003,-23.939003,14.060997,14.060997,-1.939003,-4.939003,-13.939003,-27.939003,-25.939003,-20.939003,-18.939003,-11.939003,-3.939003,-4.939003,-2.939003,-1.939003,3.060997,6.060997,5.060997,3.060997,-0.939003,-5.939003,-8.939003,-11.939003,-12.939003,-16.939003,-20.939003,-26.939003,-31.939003,-37.939003,-42.939003,-48.939003,-54.939003,-60.939003,-64.939,-70.939,-77.939,-86.939,-90.939,-91.939,-92.939,-95.939,-96.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-97.939,-100.939,-100.939,-100.939,-100.939,-101.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-103.939,-103.939,-100.939,-90.939,-13.939003,80.061,64.061,65.061,80.061,17.060997,-44.939003,-83.939,-80.939,-64.939,-62.939003,-58.939003,-55.939003,-49.939003,-41.939003,-36.939003,-31.939003,-24.939003,-20.939003,-13.939003,-4.939003,-0.939003,1.060997,-0.939003,0.06099701,2.060997,4.060997,4.060997,3.060997,1.060997,-0.939003,-5.939003,-9.939003,-12.939003,-16.939003,-23.939003,-30.939003,-36.939003,-42.939003,-49.939003,-55.939003,-59.939003,-64.939,-71.939,-79.939,-86.939,-90.939,-91.939,-92.939,-95.939,-96.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-99.939,-100.939,-100.939,-71.939,-45.939003,-38.939003,-23.939003,-3.939003,34.060997,56.060997,65.061,63.060997,61.060997,63.060997,64.061,66.061,65.061,62.060997,60.060997,61.060997,47.060997,1.060997,4.060997,24.060997,15.060997,22.060997,44.060997,34.060997,22.060997,11.060997,7.060997,3.060997,-6.939003,-11.939003,-10.939003,-19.939003,-27.939003,-29.939003,-30.939003,-29.939003,-3
2.939003,-39.939003,-48.939003,-40.939003,-33.939003,-32.939003,-37.939003,-44.939003,-40.939003,-39.939003,-42.939003,-33.939003,-16.939003,13.060997,28.060997,34.060997,26.060997,16.060997,4.060997,-32.939003,-59.939003,-52.939003,-54.939003,-57.939003,-62.939003,-67.939,-73.939,-79.939,-85.939,-86.939,-89.939,-93.939,-95.939,-97.939,-97.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-103.939,-103.939,-101.939,-82.939,-41.939003,21.060997,30.060997,33.060997,37.060997,31.060997,19.060997,-21.939003,-56.939003,-87.939,-72.939,-58.939003,-53.939003,-50.939003,-45.939003,-27.939003,-22.939003,-31.939003,-38.939003,-39.939003,-26.939003,-15.939003,-5.939003,-4.939003,-2.939003,-3.939003,-0.939003,1.060997,1.060997,0.06099701,-2.939003,-5.939003,-6.939003,-6.939003,-8.939003,-10.939003,-17.939003,-20.939003,-21.939003,-23.939003,-29.939003,-40.939003,-44.939003,-48.939003,-51.939003,-55.939003,-59.939003,-66.939,-64.939,-51.939003,-63.939003,-77.939,-91.939,-80.939,-63.939003,-85.939,-96.939,-96.939,-97.939,-97.939,-97.939,-97.939,-98.939,-99.939,-99.939,-99.939,-99.939,-84.939,-33.939003,38.060997,109.061,99.061,65.061,5.060997,-58.939003,-102.939,-102.939,-102.939,-103.939,-101.939,-99.939,-98.939,-90.939,-82.939,-81.939,-80.939,-78.939,-73.939,-69.939,-65.939,-61.939003,-57.939003,-52.939003,-46.939003,-38.939003,-34.939003,-33.939003,-34.939003,-24.939003,-17.939003,-22.939003,-16.939003,-5.939003,-5.939003,-3.939003,1.060997,3.060997,5.060997,5.060997,2.060997,0.06099701,1.060997,-2.939003,-12.939003,-23.939003,-33.939003,-34.939003,-32.939003,-28.939003,-24.939003,-20.939003,-20.939003,-14.939003,-8.939003,-4.939003,1.060997,6.060997,9.060997,14.060997,18.060997,28.060997,27.060997,0.06099701,-7.939003,-4.939003,16.060997,37.060997,60.060997,61.060997,58.060997,51.060997,49.060997,51.060997,48.060997,47.060997,48.060997,44.060997,41.060997,41.060997,42.06
0997,43.060997,23.060997,18.060997,28.060997,23.060997,19.060997,25.060997,3.060997,-29.939003,-57.939003,-62.939003,-45.939003,-33.939003,-24.939003,-16.939003,-9.939003,-4.939003,-2.939003,-0.939003,0.06099701,10.060997,20.060997,28.060997,34.060997,36.060997,36.060997,37.060997,39.060997,39.060997,36.060997,26.060997,15.060997,4.060997,17.060997,24.060997,23.060997,18.060997,13.060997,11.060997,7.060997,1.060997,-6.939003,-12.939003,-17.939003,-31.939003,-46.939003,-63.939003,-80.939,-95.939,-95.939,-83.939,-60.939003,-72.939,-72.939,-33.939003,7.060997,45.060997,39.060997,12.060997,-36.939003,-47.939003,-50.939003,-40.939003,-36.939003,-35.939003,-37.939003,-35.939003,-32.939003,-26.939003,-20.939003,-15.939003,-12.939003,-11.939003,-4.939003,-2.939003,-3.939003,-4.939003,-4.939003,-2.939003,-2.939003,-4.939003,-8.939003,-12.939003,-12.939003,-14.939003,-16.939003,-19.939003,-22.939003,-25.939003,-30.939003,-32.939003,-31.939003,-34.939003,-37.939003,-39.939003,-41.939003,-43.939003,-52.939003,-60.939003,-64.939,-60.939003,-51.939003,-35.939003,-35.939003,-41.939003,-35.939003,-31.939003,-28.939003,-26.939003,-25.939003,-26.939003,-23.939003,-18.939003,-19.939003,-17.939003,-9.939003,-8.939003,-8.939003,-10.939003,-11.939003,-10.939003,-9.939003,-9.939003,-9.939003,-13.939003,-17.939003,-18.939003,-23.939003,-27.939003,-23.939003,-19.939003,-16.939003,-22.939003,-25.939003,-26.939003,-28.939003,-32.939003,-37.939003,-27.939003,-3.939003,-11.939003,-20.939003,-25.939003,-20.939003,-13.939003,-20.939003,-26.939003,-31.939003,-40.939003,-48.939003,-50.939003,-50.939003,-50.939003,-60.939003,-66.939,-67.939,-65.939,-61.939003,-59.939003,-58.939003,-58.939003,-67.939,-70.939,-68.939,-62.939003,-56.939003,-48.939003,-44.939003,-42.939003,-39.939003,-36.939003,-35.939003,-36.939003,-36.939003,-38.939003,-41.939003,-44.939003,-42.939003,-39.939003,-35.939003,-30.939003,-25.939003,-22.939003,-21.939003,-19.939003,-15.939003,-17.939003,-26.939003,-30.939003,-34.939003,-38
.939003,-46.939003,-54.939003,-57.939003,-61.939003,-66.939,-62.939003,-58.939003,-59.939003,-58.939003,-57.939003,-52.939003,-49.939003,-50.939003,-47.939003,-43.939003,-39.939003,-37.939003,-36.939003,-39.939003,-43.939003,-43.939003,-40.939003,-37.939003,-37.939003,-34.939003,-30.939003,-25.939003,-24.939003,-27.939003,-49.939003,-66.939,-68.939,-71.939,-72.939,-64.939,-63.939003,-67.939,-73.939,-76.939,-75.939,-76.939,-78.939,-77.939,-81.939,-88.939,-76.939,-64.939,-63.939003,-79.939,-101.939,-94.939,-91.939,-93.939,-95.939,-97.939,-99.939,-100.939,-102.939,-102.939,-102.939,-102.939,-99.939,-99.939,-101.939,-102.939,-102.939,-103.939,-75.939,-21.939003,27.060997,53.060997,30.060997,-31.939003,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-96.939,-90.939,-95.939,-99.939,-103.939,-92.939,-85.939,-89.939,-92.939,-95.939,-100.939,-82.939,-41.939003,-28.939003,-28.939003,-50.939003,-56.939003,-55.939003,-66.939,-72.939,-74.939,-80.939,-84.939,-82.939,-87.939,-92.939,-90.939,-90.939,-91.939,-94.939,-97.939,-98.939,-98.939,-98.939,-100.939,-100.939,-98.939,-101.939,-103.939,-101.939,-102.939,-103.939,-103.939,-102.939,-101.939,-102.939,-99.939,-88.939,-75.939,-63.939003,-65.939,-63.939003,-59.939003,-57.939003,-57.939003,-58.939003,-58.939003,-56.939003,-52.939003,-52.939003,-55.939003,-55.939003,-54.939003,-50.939003,-52.939003,-56.939003,-59.939003,-60.939003,-61.939003,-59.939003,-58.939003,-57.939003,-46.939003,-32.939003,-23.939003,-15.939003,-6.939003,-9.939003,-11.939003,-10.939003,-9.939003,-7.939003,-5.939003,-10.939003,-22.939003,-37.939003,-47.939003,-45.939003,-46.939003,-49.939003,-51.939003,-54.939003,-60.939003,-64.939,-66.939,-65.939,-68.939,-72.939,-74.939,-77.939,-78.939,-82.939,-84.939,-87.939,-86.939,-87.939,-90.939,-94.939,-99.939,-100.939,-98.939,-90.939,-53.939003,-14.939003,-57.939003,-87.939,-101.939,-102.939,-101.939,-98.939,-97.939,-96.939,-94.939,-93.93
9,-95.939,-94.939,-93.939,-92.939,-89.939,-86.939,-82.939,-78.939,-75.939,-74.939,-72.939,-66.939,-62.939003,-58.939003,-55.939003,-52.939003,-48.939003,-44.939003,-40.939003,-38.939003,-35.939003,-31.939003,-29.939003,-26.939003,-21.939003,-18.939003,-14.939003,-7.939003,-9.939003,-14.939003,-14.939003,-18.939003,-24.939003,-27.939003,-28.939003,-26.939003,-28.939003,-31.939003,-33.939003,-35.939003,-38.939003,-38.939003,-38.939003,-43.939003,-47.939003,-51.939003,-56.939003,-60.939003,-61.939003,-62.939003,-62.939003,-63.939003,-67.939,-70.939,-73.939,-76.939,-80.939,-83.939,-86.939,-88.939,-91.939,-94.939,-98.939,-100.939,-101.939,-101.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-103.939,-102.939,-102.939,-100.939,-97.939,-96.939,-95.939,-93.939,-91.939,-88.939,-84.939,-79.939,-74.939,-73.939,-69.939,-62.939003,-57.939003,-52.939003,-48.939003,-20.939003,13.060997,4.060997,-0.939003,-0.939003,-17.939003,-31.939003,-32.939003,-29.939003,-26.939003,-24.939003,-21.939003,-20.939003,-20.939003,-20.939003,-24.939003,-24.939003,-22.939003,-24.939003,-27.939003,-31.939003,-30.939003,-30.939003,-35.939003,-37.939003,-40.939003,-43.939003,-46.939003,-50.939003,-53.939003,-54.939003,-57.939003,-59.939003,-62.939003,-64.939,-66.939,-71.939,-73.939,-76.939,-81.939,-83.939,-85.939,-88.939,-91.939,-95.939,-98.939,-100.939,-101.939,-101.939,-101.939,-102.939,-103.939,-103.939,-100.939,-98.939,-96.939,-94.939,-92.939,-91.939,-92.939,-95.939,-67.939,-42.939003,-34.939003,-23.939003,-9.939003,12.060997,25.060997,31.060997,23.060997,15.060997,14.060997,11.060997,9.060997,8.060997,7.060997,8.060997,5.060997,-0.939003,-11.939003,-11.939003,-6.939003,-13.939003,-12.939003,-2.939003,-7.939003,-11.939003,-15.939003,-14.939003,-12.939003,-18.939003,-19.939003,-15.939003,-13.939003,-10.939003,-9.939003,-7.939003,-4.939003,-3.939003,-2.939003,-3.939003,-3.939003,-4.939003,-5.939003,1.060997,9.060997,8.060997,11.060997,17.060997,8.060997,3.060997,8.060997,21.060997,35
.060997,29.060997,22.060997,12.060997,-36.939003,-75.939,-80.939,-83.939,-85.939,-87.939,-89.939,-92.939,-95.939,-98.939,-99.939,-100.939,-101.939,-103.939,-103.939,-103.939,-102.939,-101.939,-101.939,-98.939,-96.939,-96.939,-95.939,-92.939,-89.939,-86.939,-82.939,-78.939,-74.939,-73.939,-69.939,-65.939,-64.939,-61.939003,-57.939003,-56.939003,-55.939003,-45.939003,-22.939003,14.060997,26.060997,33.060997,36.060997,30.060997,18.060997,-22.939003,-47.939003,-54.939003,-31.939003,-14.939003,-21.939003,-24.939003,-25.939003,-21.939003,-21.939003,-26.939003,-24.939003,-22.939003,-26.939003,-30.939003,-33.939003,-34.939003,-35.939003,-37.939003,-37.939003,-38.939003,-42.939003,-46.939003,-50.939003,-55.939003,-58.939003,-58.939003,-59.939003,-60.939003,-64.939,-66.939,-66.939,-68.939,-71.939,-76.939,-78.939,-80.939,-82.939,-83.939,-85.939,-89.939,-82.939,-65.939,-77.939,-90.939,-99.939,-89.939,-74.939,-93.939,-103.939,-103.939,-101.939,-98.939,-96.939,-94.939,-93.939,-88.939,-84.939,-81.939,-82.939,-72.939,-39.939003,0.06099701,38.060997,29.060997,11.060997,-15.939003,-40.939003,-57.939003,-55.939003,-51.939003,-47.939003,-44.939003,-43.939003,-43.939003,-36.939003,-30.939003,-28.939003,-28.939003,-29.939003,-27.939003,-26.939003,-28.939003,-24.939003,-20.939003,-19.939003,-18.939003,-17.939003,-18.939003,-21.939003,-26.939003,-23.939003,-22.939003,-29.939003,-31.939003,-32.939003,-34.939003,-35.939003,-36.939003,-35.939003,-34.939003,-39.939003,-44.939003,-50.939003,-51.939003,-53.939003,-53.939003,-41.939003,-31.939003,-29.939003,-39.939003,-49.939003,-24.939003,4.060997,34.060997,34.060997,32.060997,35.060997,38.060997,40.060997,39.060997,38.060997,36.060997,37.060997,30.060997,1.060997,-8.939003,-7.939003,12.060997,29.060997,42.060997,34.060997,27.060997,22.060997,18.060997,15.060997,11.060997,5.060997,0.06099701,-1.939003,-4.939003,-10.939003,-10.939003,-7.939003,-16.939003,-20.939003,-20.939003,-24.939003,-27.939003,-26.939003,-28.939003,-31.939003,-47.939003,-61.9
39003,-73.939,-55.939003,-39.939003,-36.939003,-40.939003,-46.939003,-41.939003,-26.939003,-1.939003,-11.939003,-20.939003,-16.939003,-10.939003,-5.939003,-3.939003,1.060997,7.060997,1.060997,-7.939003,-19.939003,-14.939003,-4.939003,-10.939003,-12.939003,-12.939003,-19.939003,-23.939003,-13.939003,-16.939003,-26.939003,-39.939003,-48.939003,-55.939003,-61.939003,-69.939,-82.939,-70.939,-49.939003,-36.939003,-28.939003,-25.939003,-28.939003,-27.939003,-8.939003,12.060997,31.060997,21.060997,4.060997,-20.939003,-23.939003,-24.939003,-20.939003,-21.939003,-23.939003,-28.939003,-30.939003,-29.939003,-28.939003,-27.939003,-26.939003,-26.939003,-26.939003,-27.939003,-26.939003,-25.939003,-28.939003,-29.939003,-27.939003,-25.939003,-24.939003,-25.939003,-25.939003,-24.939003,-21.939003,-19.939003,-20.939003,-21.939003,-21.939003,-21.939003,-19.939003,-17.939003,-20.939003,-22.939003,-22.939003,-21.939003,-21.939003,-26.939003,-27.939003,-24.939003,-19.939003,-15.939003,-16.939003,-21.939003,-27.939003,-25.939003,-26.939003,-29.939003,-29.939003,-30.939003,-35.939003,-38.939003,-38.939003,-40.939003,-41.939003,-39.939003,-42.939003,-45.939003,-49.939003,-52.939003,-54.939003,-54.939003,-54.939003,-52.939003,-54.939003,-55.939003,-55.939003,-54.939003,-52.939003,-43.939003,-38.939003,-36.939003,-39.939003,-41.939003,-39.939003,-40.939003,-43.939003,-42.939003,-39.939003,-34.939003,-35.939003,-37.939003,-40.939003,-41.939003,-42.939003,-45.939003,-47.939003,-49.939003,-50.939003,-51.939003,-50.939003,-48.939003,-47.939003,-50.939003,-51.939003,-48.939003,-43.939003,-38.939003,-35.939003,-31.939003,-27.939003,-26.939003,-25.939003,-27.939003,-21.939003,-15.939003,-9.939003,-24.939003,-46.939003,-49.939003,-47.939003,-44.939003,-45.939003,-46.939003,-47.939003,-47.939003,-47.939003,-41.939003,-37.939003,-33.939003,-29.939003,-25.939003,-22.939003,-22.939003,-22.939003,-20.939003,-23.939003,-29.939003,-32.939003,-35.939003,-31.939003,-16.939003,1.060997,-15.939003,-24.939003,-2
2.939003,-25.939003,-27.939003,-30.939003,-27.939003,-23.939003,-22.939003,-21.939003,-21.939003,-21.939003,-23.939003,-29.939003,-33.939003,-35.939003,-27.939003,-25.939003,-29.939003,-29.939003,-29.939003,-31.939003,-29.939003,-28.939003,-28.939003,-30.939003,-33.939003,-47.939003,-58.939003,-59.939003,-58.939003,-55.939003,-52.939003,-52.939003,-54.939003,-57.939003,-58.939003,-54.939003,-55.939003,-57.939003,-56.939003,-58.939003,-61.939003,-58.939003,-54.939003,-52.939003,-58.939003,-66.939,-65.939,-64.939,-63.939003,-62.939003,-63.939003,-67.939,-69.939,-70.939,-68.939,-67.939,-66.939,-66.939,-66.939,-70.939,-70.939,-69.939,-71.939,-61.939003,-37.939003,-14.939003,-1.939003,-10.939003,-41.939003,-78.939,-78.939,-79.939,-78.939,-82.939,-84.939,-84.939,-85.939,-87.939,-88.939,-89.939,-89.939,-88.939,-87.939,-89.939,-88.939,-87.939,-96.939,-101.939,-103.939,-93.939,-87.939,-92.939,-95.939,-97.939,-101.939,-86.939,-53.939003,-35.939003,-29.939003,-44.939003,-50.939003,-51.939003,-57.939003,-61.939003,-65.939,-68.939,-70.939,-70.939,-75.939,-80.939,-83.939,-85.939,-87.939,-89.939,-91.939,-93.939,-96.939,-99.939,-101.939,-101.939,-99.939,-102.939,-102.939,-99.939,-99.939,-101.939,-102.939,-103.939,-102.939,-100.939,-98.939,-94.939,-90.939,-87.939,-88.939,-87.939,-85.939,-81.939,-76.939,-70.939,-72.939,-79.939,-80.939,-81.939,-79.939,-80.939,-79.939,-74.939,-73.939,-72.939,-56.939003,-46.939003,-41.939003,-41.939003,-42.939003,-25.939003,-8.939003,14.060997,20.060997,23.060997,23.060997,20.060997,20.060997,27.060997,33.060997,39.060997,44.060997,26.060997,-12.939003,-55.939003,-86.939,-84.939,-85.939,-86.939,-88.939,-89.939,-91.939,-96.939,-99.939,-96.939,-95.939,-96.939,-97.939,-97.939,-94.939,-95.939,-97.939,-100.939,-98.939,-93.939,-93.939,-93.939,-94.939,-93.939,-91.939,-88.939,-53.939003,-12.939003,-46.939003,-71.939,-89.939,-88.939,-84.939,-77.939,-75.939,-74.939,-70.939,-68.939,-69.939,-68.939,-67.939,-66.939,-62.939003,-58.939003,-54.939003,-51.939003,-46.939
003,-44.939003,-42.939003,-34.939003,-28.939003,-24.939003,-21.939003,-18.939003,-15.939003,-11.939003,-6.939003,-6.939003,-6.939003,-5.939003,-9.939003,-9.939003,-4.939003,-10.939003,-12.939003,-3.939003,-2.939003,-3.939003,1.060997,-3.939003,-19.939003,-37.939003,-50.939003,-49.939003,-57.939003,-66.939,-69.939,-73.939,-78.939,-80.939,-83.939,-88.939,-92.939,-96.939,-100.939,-102.939,-102.939,-102.939,-100.939,-98.939,-98.939,-99.939,-99.939,-100.939,-101.939,-101.939,-101.939,-101.939,-100.939,-98.939,-97.939,-96.939,-95.939,-94.939,-93.939,-92.939,-91.939,-90.939,-87.939,-85.939,-86.939,-85.939,-84.939,-81.939,-78.939,-75.939,-72.939,-70.939,-67.939,-63.939003,-59.939003,-54.939003,-47.939003,-42.939003,-41.939003,-38.939003,-30.939003,-24.939003,-19.939003,-20.939003,-22.939003,-24.939003,-27.939003,-32.939003,-38.939003,-30.939003,-21.939003,-10.939003,-12.939003,-17.939003,-16.939003,-15.939003,-16.939003,-19.939003,-24.939003,-33.939003,-37.939003,-40.939003,-45.939003,-53.939003,-66.939,-67.939,-69.939,-74.939,-78.939,-83.939,-88.939,-93.939,-97.939,-100.939,-100.939,-99.939,-101.939,-103.939,-101.939,-100.939,-101.939,-100.939,-100.939,-101.939,-100.939,-99.939,-98.939,-97.939,-96.939,-95.939,-94.939,-93.939,-91.939,-89.939,-88.939,-88.939,-86.939,-81.939,-77.939,-74.939,-70.939,-68.939,-65.939,-65.939,-69.939,-49.939003,-32.939003,-27.939003,-22.939003,-15.939003,-9.939003,-5.939003,-2.939003,-11.939003,-18.939003,-20.939003,-23.939003,-26.939003,-27.939003,-25.939003,-20.939003,-23.939003,-24.939003,-18.939003,-17.939003,-19.939003,-23.939003,-22.939003,-16.939003,-15.939003,-16.939003,-16.939003,-12.939003,-5.939003,-6.939003,-3.939003,1.060997,10.060997,18.060997,22.060997,26.060997,29.060997,33.060997,38.060997,45.060997,35.060997,25.060997,20.060997,37.060997,60.060997,53.060997,56.060997,69.061,44.060997,20.060997,4.060997,15.060997,37.060997,32.060997,26.060997,19.060997,-36.939003,-83.939,-94.939,-97.939,-95.939,-95.939,-95.939,-94.939,-92.939,-91
.939,-90.939,-89.939,-87.939,-87.939,-86.939,-84.939,-83.939,-81.939,-79.939,-75.939,-73.939,-71.939,-69.939,-64.939,-60.939003,-56.939003,-50.939003,-46.939003,-42.939003,-40.939003,-35.939003,-30.939003,-29.939003,-27.939003,-23.939003,-24.939003,-25.939003,-23.939003,-12.939003,8.060997,22.060997,33.060997,36.060997,31.060997,19.060997,-24.939003,-45.939003,-42.939003,-19.939003,-5.939003,-18.939003,-25.939003,-27.939003,-9.939003,-1.939003,-2.939003,9.060997,7.060997,-30.939003,-53.939003,-68.939,-70.939,-73.939,-75.939,-77.939,-79.939,-85.939,-89.939,-93.939,-99.939,-101.939,-102.939,-101.939,-101.939,-101.939,-101.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-95.939,-94.939,-93.939,-91.939,-82.939,-67.939,-74.939,-82.939,-85.939,-78.939,-67.939,-79.939,-84.939,-82.939,-80.939,-77.939,-72.939,-69.939,-66.939,-60.939003,-55.939003,-51.939003,-53.939003,-51.939003,-38.939003,-29.939003,-24.939003,-28.939003,-29.939003,-30.939003,-27.939003,-25.939003,-21.939003,-16.939003,-12.939003,-10.939003,-10.939003,-11.939003,-9.939003,-6.939003,-6.939003,-8.939003,-11.939003,-12.939003,-14.939003,-19.939003,-17.939003,-14.939003,-15.939003,-18.939003,-23.939003,-27.939003,-32.939003,-39.939003,-40.939003,-44.939003,-50.939003,-58.939003,-66.939,-69.939,-72.939,-76.939,-76.939,-77.939,-83.939,-88.939,-94.939,-98.939,-96.939,-87.939,-54.939003,-29.939003,-28.939003,-44.939003,-62.939003,-27.939003,12.060997,59.060997,52.060997,44.060997,45.060997,43.060997,39.060997,37.060997,32.060997,26.060997,21.060997,11.060997,-10.939003,-18.939003,-19.939003,-5.939003,3.060997,5.060997,-5.939003,-12.939003,-13.939003,-16.939003,-21.939003,-25.939003,-31.939003,-40.939003,-39.939003,-40.939003,-46.939003,-45.939003,-43.939003,-42.939003,-44.939003,-47.939003,-49.939003,-50.939003,-52.939003,-39.939003,-23.939003,-35.939003,-53.939003,-78.939,-60.939003,-45.939003,-48.939003,-60.939003,-73.939,-68.939,-46.939003,-9.939003,-33.939003,-55.939003,-56.939003,-52.939003,-45.939003,-43.
939003,-35.939003,-21.939003,-33.939003,-45.939003,-55.939003,-38.939003,-15.939003,-39.939003,-49.939003,-46.939003,-54.939003,-56.939003,-39.939003,-42.939003,-54.939003,-67.939,-78.939,-85.939,-84.939,-85.939,-91.939,-59.939003,-13.939003,1.060997,4.060997,-4.939003,-1.939003,-0.939003,-2.939003,-0.939003,-0.939003,-10.939003,-15.939003,-14.939003,-11.939003,-9.939003,-10.939003,-13.939003,-16.939003,-23.939003,-26.939003,-25.939003,-28.939003,-31.939003,-33.939003,-34.939003,-35.939003,-40.939003,-41.939003,-39.939003,-42.939003,-44.939003,-43.939003,-42.939003,-40.939003,-39.939003,-37.939003,-35.939003,-30.939003,-27.939003,-27.939003,-26.939003,-25.939003,-22.939003,-20.939003,-18.939003,-20.939003,-21.939003,-21.939003,-21.939003,-19.939003,-5.939003,6.060997,17.060997,15.060997,5.060997,-16.939003,-25.939003,-27.939003,-30.939003,-34.939003,-39.939003,-40.939003,-41.939003,-47.939003,-52.939003,-56.939003,-57.939003,-59.939003,-60.939003,-64.939,-69.939,-72.939,-75.939,-78.939,-79.939,-78.939,-75.939,-75.939,-74.939,-73.939,-69.939,-64.939,-56.939003,-51.939003,-51.939003,-52.939003,-51.939003,-48.939003,-47.939003,-48.939003,-43.939003,-43.939003,-48.939003,-45.939003,-44.939003,-44.939003,-48.939003,-52.939003,-50.939003,-49.939003,-46.939003,-45.939003,-44.939003,-43.939003,-35.939003,-25.939003,-24.939003,-23.939003,-20.939003,-15.939003,-10.939003,-9.939003,-5.939003,0.06099701,8.060997,9.060997,3.060997,8.060997,13.060997,14.060997,-12.939003,-48.939003,-53.939003,-52.939003,-48.939003,-48.939003,-48.939003,-48.939003,-47.939003,-46.939003,-40.939003,-35.939003,-32.939003,-29.939003,-25.939003,-22.939003,-23.939003,-25.939003,-25.939003,-27.939003,-29.939003,-34.939003,-35.939003,-24.939003,10.060997,49.060997,21.060997,10.060997,17.060997,9.060997,2.060997,-1.939003,3.060997,10.060997,10.060997,10.060997,10.060997,9.060997,2.060997,-14.939003,-25.939003,-30.939003,-5.939003,4.060997,-0.939003,-3.939003,-5.939003,-6.939003,-7.939003,-10.939003,-8.9390
03,-15.939003,-28.939003,-49.939003,-62.939003,-62.939003,-59.939003,-55.939003,-56.939003,-57.939003,-57.939003,-57.939003,-56.939003,-52.939003,-52.939003,-53.939003,-53.939003,-53.939003,-52.939003,-51.939003,-50.939003,-51.939003,-51.939003,-51.939003,-54.939003,-54.939003,-51.939003,-49.939003,-49.939003,-52.939003,-53.939003,-53.939003,-50.939003,-48.939003,-47.939003,-48.939003,-49.939003,-51.939003,-51.939003,-49.939003,-54.939003,-52.939003,-46.939003,-42.939003,-40.939003,-41.939003,-50.939003,-61.939003,-60.939003,-61.939003,-60.939003,-66.939,-69.939,-69.939,-70.939,-72.939,-73.939,-74.939,-73.939,-71.939,-70.939,-72.939,-76.939,-83.939,-95.939,-102.939,-103.939,-96.939,-90.939,-93.939,-96.939,-97.939,-101.939,-89.939,-61.939003,-40.939003,-30.939003,-40.939003,-46.939003,-49.939003,-50.939003,-53.939003,-57.939003,-58.939003,-59.939003,-60.939003,-63.939003,-66.939,-73.939,-77.939,-79.939,-81.939,-83.939,-86.939,-91.939,-96.939,-100.939,-100.939,-99.939,-100.939,-100.939,-97.939,-98.939,-99.939,-101.939,-102.939,-102.939,-99.939,-97.939,-97.939,-98.939,-99.939,-99.939,-100.939,-100.939,-95.939,-87.939,-74.939,-80.939,-93.939,-98.939,-99.939,-95.939,-97.939,-97.939,-93.939,-88.939,-81.939,-50.939003,-30.939003,-22.939003,-22.939003,-23.939003,45.060997,54.060997,65.061,65.061,64.061,64.061,61.060997,60.060997,61.060997,58.060997,54.060997,64.061,45.060997,-3.939003,-59.939003,-99.939,-99.939,-97.939,-95.939,-97.939,-98.939,-100.939,-101.939,-100.939,-95.939,-94.939,-92.939,-93.939,-92.939,-89.939,-94.939,-96.939,-95.939,-89.939,-82.939,-78.939,-76.939,-74.939,-71.939,-67.939,-59.939003,-42.939003,-23.939003,-34.939003,-43.939003,-48.939003,-41.939003,-33.939003,-22.939003,-19.939003,-16.939003,-11.939003,-8.939003,-4.939003,-5.939003,-6.939003,-7.939003,-7.939003,-7.939003,-9.939003,-12.939003,-15.939003,-14.939003,-13.939003,-13.939003,-19.939003,-26.939003,-29.939003,-32.939003,-35.939003,-39.939003,-43.939003,-50.939003,-55.939003,-60.939003,-62.93900
3,-67.939,-76.939,-65.939,-31.939003,51.060997,67.061,57.060997,70.061,48.060997,-5.939003,-55.939003,-91.939,-91.939,-92.939,-95.939,-95.939,-96.939,-97.939,-98.939,-98.939,-100.939,-101.939,-101.939,-101.939,-101.939,-100.939,-100.939,-101.939,-100.939,-97.939,-94.939,-96.939,-97.939,-96.939,-96.939,-95.939,-92.939,-88.939,-83.939,-79.939,-74.939,-69.939,-63.939003,-59.939003,-55.939003,-52.939003,-49.939003,-38.939003,-32.939003,-31.939003,-26.939003,-19.939003,-13.939003,-13.939003,-14.939003,-13.939003,-10.939003,-6.939003,-4.939003,-2.939003,-2.939003,-6.939003,-10.939003,-13.939003,-19.939003,-27.939003,-31.939003,-34.939003,-35.939003,-14.939003,13.060997,18.060997,27.060997,42.060997,16.060997,-19.939003,-66.939,-81.939,-81.939,-82.939,-82.939,-81.939,-81.939,-82.939,-86.939,-88.939,-88.939,-90.939,-92.939,-94.939,-94.939,-94.939,-96.939,-97.939,-98.939,-100.939,-101.939,-102.939,-102.939,-102.939,-101.939,-100.939,-101.939,-100.939,-99.939,-100.939,-99.939,-97.939,-93.939,-90.939,-86.939,-82.939,-77.939,-73.939,-68.939,-64.939,-59.939003,-53.939003,-47.939003,-42.939003,-37.939003,-31.939003,-25.939003,-20.939003,-19.939003,-15.939003,-11.939003,-9.939003,-7.939003,-5.939003,-5.939003,-8.939003,-16.939003,-18.939003,-17.939003,-22.939003,-24.939003,-22.939003,-20.939003,-17.939003,-9.939003,-5.939003,-1.939003,1.060997,6.060997,14.060997,20.060997,17.060997,-8.939003,-2.939003,15.060997,12.060997,25.060997,54.060997,55.060997,51.060997,41.060997,42.060997,47.060997,52.060997,55.060997,57.060997,58.060997,59.060997,61.060997,61.060997,60.060997,62.060997,67.061,72.061,58.060997,43.060997,27.060997,47.060997,78.061,65.061,64.061,79.061,49.060997,20.060997,2.060997,16.060997,39.060997,31.060997,25.060997,22.060997,-26.939003,-68.939,-79.939,-77.939,-69.939,-70.939,-68.939,-63.939003,-56.939003,-49.939003,-49.939003,-44.939003,-39.939003,-35.939003,-30.939003,-23.939003,-21.939003,-18.939003,-12.939003,-10.939003,-11.939003,-8.939003,-4.939003,-1.939003,0.0609
9701,0.06099701,-2.939003,-6.939003,-11.939003,-8.939003,-9.939003,-12.939003,-15.939003,-19.939003,-27.939003,-35.939003,-43.939003,-43.939003,-26.939003,8.060997,23.060997,32.060997,37.060997,33.060997,22.060997,-25.939003,-61.939003,-87.939,-84.939,-80.939,-83.939,-85.939,-78.939,5.060997,48.060997,51.060997,64.061,48.060997,-35.939003,-75.939,-95.939,-95.939,-96.939,-96.939,-96.939,-97.939,-97.939,-98.939,-100.939,-102.939,-103.939,-103.939,-103.939,-101.939,-96.939,-92.939,-89.939,-87.939,-83.939,-79.939,-79.939,-75.939,-66.939,-62.939003,-60.939003,-52.939003,-47.939003,-45.939003,-40.939003,-35.939003,-31.939003,-29.939003,-29.939003,-26.939003,-20.939003,-11.939003,-14.939003,-16.939003,-14.939003,-9.939003,-3.939003,-5.939003,-6.939003,-7.939003,-9.939003,-12.939003,-18.939003,-26.939003,-33.939003,-26.939003,-24.939003,-28.939003,-32.939003,-34.939003,-33.939003,-38.939003,-44.939003,-44.939003,-45.939003,-50.939003,-54.939003,-60.939003,-68.939,-72.939,-75.939,-77.939,-79.939,-84.939,-83.939,-83.939,-81.939,-83.939,-84.939,-85.939,-87.939,-88.939,-89.939,-89.939,-91.939,-93.939,-95.939,-95.939,-96.939,-97.939,-97.939,-97.939,-98.939,-99.939,-99.939,-101.939,-96.939,-82.939,-50.939003,-27.939003,-33.939003,-43.939003,-53.939003,-39.939003,-21.939003,0.06099701,-5.939003,-13.939003,-15.939003,-25.939003,-36.939003,-36.939003,-36.939003,-38.939003,-42.939003,-44.939003,-42.939003,-41.939003,-41.939003,-44.939003,-47.939003,-50.939003,-50.939003,-47.939003,-41.939003,-39.939003,-40.939003,-40.939003,-40.939003,-38.939003,-35.939003,-32.939003,-27.939003,-24.939003,-21.939003,-24.939003,-20.939003,-12.939003,-11.939003,-9.939003,-0.939003,7.060997,11.060997,-26.939003,-41.939003,-33.939003,-29.939003,-28.939003,-36.939003,-44.939003,-52.939003,-49.939003,-41.939003,-24.939003,-39.939003,-54.939003,-60.939003,-57.939003,-52.939003,-58.939003,-49.939003,-22.939003,-34.939003,-44.939003,-37.939003,-26.939003,-16.939003,-48.939003,-61.939003,-51.939003,-55.939003,
-55.939003,-47.939003,-51.939003,-59.939003,-69.939,-75.939,-78.939,-77.939,-77.939,-76.939,-52.939003,-21.939003,-25.939003,-28.939003,-30.939003,-32.939003,-36.939003,-43.939003,-48.939003,-51.939003,-53.939003,-47.939003,-33.939003,-29.939003,-27.939003,-23.939003,-23.939003,-24.939003,-21.939003,-18.939003,-14.939003,-15.939003,-16.939003,-14.939003,-14.939003,-14.939003,-13.939003,-14.939003,-17.939003,-16.939003,-16.939003,-19.939003,-23.939003,-28.939003,-30.939003,-32.939003,-36.939003,-38.939003,-40.939003,-42.939003,-45.939003,-47.939003,-49.939003,-52.939003,-56.939003,-58.939003,-60.939003,-66.939,-72.939,-72.939,-15.939003,20.060997,35.060997,10.060997,-26.939003,-70.939,-79.939,-72.939,-72.939,-72.939,-73.939,-69.939,-66.939,-63.939003,-60.939003,-57.939003,-56.939003,-52.939003,-48.939003,-46.939003,-44.939003,-43.939003,-40.939003,-39.939003,-37.939003,-35.939003,-35.939003,-34.939003,-34.939003,-34.939003,-36.939003,-39.939003,-43.939003,-46.939003,-49.939003,-47.939003,-46.939003,-43.939003,-42.939003,-43.939003,-38.939003,-28.939003,-11.939003,-17.939003,-22.939003,-22.939003,-17.939003,-10.939003,-8.939003,-5.939003,-1.939003,-12.939003,-22.939003,-22.939003,-0.939003,26.060997,20.060997,13.060997,8.060997,8.060997,7.060997,1.060997,1.060997,2.060997,3.060997,-0.939003,-8.939003,-5.939003,-4.939003,-13.939003,-27.939003,-43.939003,-42.939003,-40.939003,-38.939003,-33.939003,-29.939003,-27.939003,-32.939003,-38.939003,-38.939003,-36.939003,-31.939003,-28.939003,-25.939003,-23.939003,-24.939003,-25.939003,-25.939003,-25.939003,-25.939003,-34.939003,-37.939003,-21.939003,11.060997,44.060997,20.060997,12.060997,21.060997,14.060997,7.060997,8.060997,14.060997,22.060997,27.060997,28.060997,28.060997,28.060997,21.060997,-1.939003,-15.939003,-20.939003,24.060997,44.060997,38.060997,39.060997,42.060997,45.060997,38.060997,30.060997,46.060997,33.060997,-7.939003,-57.939003,-94.939,-94.939,-93.939,-92.939,-92.939,-92.939,-92.939,-92.939,-92.939,-91.939,-91.
939,-91.939,-91.939,-91.939,-91.939,-74.939,-64.939,-76.939,-85.939,-91.939,-92.939,-91.939,-89.939,-90.939,-90.939,-87.939,-84.939,-81.939,-79.939,-79.939,-78.939,-76.939,-75.939,-73.939,-72.939,-74.939,-75.939,-60.939003,-28.939003,-15.939003,-12.939003,-28.939003,-51.939003,-73.939,-65.939,-63.939003,-65.939,-68.939,-69.939,-67.939,-65.939,-64.939,-62.939003,-59.939003,-57.939003,-56.939003,-55.939003,-51.939003,-61.939003,-75.939,-94.939,-103.939,-103.939,-99.939,-94.939,-89.939,-91.939,-96.939,-100.939,-87.939,-55.939003,-37.939003,-28.939003,-39.939003,-45.939003,-48.939003,-48.939003,-49.939003,-52.939003,-54.939003,-54.939003,-54.939003,-52.939003,-51.939003,-57.939003,-59.939003,-59.939003,-67.939,-74.939,-78.939,-82.939,-86.939,-92.939,-95.939,-95.939,-95.939,-95.939,-99.939,-100.939,-101.939,-98.939,-98.939,-100.939,-98.939,-95.939,-92.939,-89.939,-87.939,-87.939,-88.939,-91.939,-89.939,-82.939,-66.939,-72.939,-88.939,-92.939,-94.939,-94.939,-95.939,-96.939,-98.939,-92.939,-79.939,-43.939003,-22.939003,-13.939003,-10.939003,-8.939003,67.061,70.061,74.061,72.061,73.061,76.061,74.061,73.061,70.061,66.061,63.060997,71.061,52.060997,5.060997,-53.939003,-96.939,-96.939,-94.939,-90.939,-90.939,-89.939,-88.939,-84.939,-80.939,-75.939,-69.939,-62.939003,-63.939003,-61.939003,-57.939003,-56.939003,-56.939003,-57.939003,-51.939003,-44.939003,-40.939003,-38.939003,-36.939003,-33.939003,-31.939003,-30.939003,-28.939003,-25.939003,-26.939003,-26.939003,-27.939003,-23.939003,-20.939003,-17.939003,-17.939003,-17.939003,-16.939003,-17.939003,-19.939003,-24.939003,-29.939003,-32.939003,-34.939003,-34.939003,-38.939003,-43.939003,-46.939003,-46.939003,-46.939003,-48.939003,-53.939003,-59.939003,-61.939003,-63.939003,-66.939,-70.939,-73.939,-78.939,-82.939,-85.939,-87.939,-91.939,-97.939,-86.939,-44.939003,65.061,87.061,74.061,90.061,69.061,11.060997,-55.939003,-103.939,-102.939,-101.939,-100.939,-99.939,-97.939,-95.939,-95.939,-95.939,-91.939,-88.939,-86.939,-83.939,-80.93
9,-74.939,-71.939,-69.939,-67.939,-62.939003,-56.939003,-55.939003,-54.939003,-53.939003,-50.939003,-47.939003,-45.939003,-43.939003,-40.939003,-38.939003,-35.939003,-32.939003,-29.939003,-28.939003,-26.939003,-26.939003,-25.939003,-20.939003,-20.939003,-25.939003,-23.939003,-21.939003,-19.939003,-24.939003,-30.939003,-31.939003,-30.939003,-29.939003,-30.939003,-31.939003,-35.939003,-40.939003,-46.939003,-47.939003,-52.939003,-59.939003,-63.939003,-66.939,-66.939,-20.939003,40.060997,53.060997,67.061,84.061,46.060997,-7.939003,-80.939,-102.939,-101.939,-102.939,-102.939,-102.939,-100.939,-100.939,-102.939,-103.939,-103.939,-103.939,-102.939,-99.939,-97.939,-96.939,-95.939,-94.939,-92.939,-89.939,-85.939,-83.939,-79.939,-75.939,-71.939,-68.939,-67.939,-62.939003,-59.939003,-56.939003,-53.939003,-51.939003,-47.939003,-45.939003,-42.939003,-38.939003,-34.939003,-31.939003,-29.939003,-27.939003,-25.939003,-24.939003,-23.939003,-21.939003,-19.939003,-18.939003,-16.939003,-16.939003,-18.939003,-19.939003,-20.939003,-21.939003,-24.939003,-28.939003,-23.939003,-21.939003,-23.939003,-20.939003,-14.939003,-4.939003,6.060997,18.060997,17.060997,18.060997,26.060997,30.060997,33.060997,38.060997,40.060997,40.060997,53.060997,49.060997,0.06099701,8.060997,37.060997,31.060997,46.060997,83.061,84.061,76.061,61.060997,59.060997,64.061,66.061,67.061,67.061,62.060997,60.060997,61.060997,59.060997,55.060997,54.060997,53.060997,52.060997,44.060997,31.060997,12.060997,22.060997,42.060997,33.060997,28.060997,30.060997,12.060997,1.060997,6.060997,21.060997,38.060997,31.060997,25.060997,20.060997,-23.939003,-53.939003,-39.939003,-32.939003,-29.939003,-30.939003,-31.939003,-29.939003,-27.939003,-24.939003,-25.939003,-23.939003,-21.939003,-21.939003,-21.939003,-21.939003,-21.939003,-20.939003,-21.939003,-24.939003,-28.939003,-27.939003,-27.939003,-28.939003,-30.939003,-32.939003,-37.939003,-42.939003,-46.939003,-44.939003,-46.939003,-48.939003,-51.939003,-55.939003,-61.939003,-67.939,-73.939,
-72.939,-44.939003,13.060997,27.060997,34.060997,36.060997,32.060997,23.060997,-20.939003,-61.939003,-100.939,-102.939,-101.939,-100.939,-102.939,-94.939,-1.939003,46.060997,47.060997,56.060997,40.060997,-38.939003,-75.939,-92.939,-85.939,-81.939,-78.939,-76.939,-74.939,-70.939,-67.939,-65.939,-64.939,-62.939003,-60.939003,-60.939003,-58.939003,-51.939003,-47.939003,-44.939003,-41.939003,-39.939003,-37.939003,-35.939003,-33.939003,-30.939003,-28.939003,-25.939003,-25.939003,-24.939003,-23.939003,-20.939003,-18.939003,-18.939003,-20.939003,-22.939003,-25.939003,-24.939003,-19.939003,-24.939003,-29.939003,-30.939003,-29.939003,-26.939003,-29.939003,-31.939003,-33.939003,-39.939003,-39.939003,-27.939003,-2.939003,25.060997,29.060997,18.060997,-5.939003,-40.939003,-66.939,-66.939,-70.939,-74.939,-75.939,-77.939,-79.939,-83.939,-88.939,-93.939,-96.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-94.939,-93.939,-92.939,-91.939,-90.939,-88.939,-85.939,-82.939,-77.939,-75.939,-74.939,-71.939,-68.939,-64.939,-64.939,-63.939003,-61.939003,-59.939003,-56.939003,-56.939003,-52.939003,-45.939003,-33.939003,-25.939003,-29.939003,-34.939003,-39.939003,-37.939003,-33.939003,-28.939003,-33.939003,-36.939003,-34.939003,-38.939003,-45.939003,-43.939003,-41.939003,-41.939003,-39.939003,-37.939003,-36.939003,-34.939003,-33.939003,-30.939003,-25.939003,-19.939003,-18.939003,-16.939003,-12.939003,-8.939003,-5.939003,-2.939003,-0.939003,3.060997,4.060997,6.060997,10.060997,12.060997,11.060997,6.060997,8.060997,18.060997,21.060997,21.060997,19.060997,7.060997,-9.939003,-35.939003,-38.939003,-18.939003,-12.939003,-11.939003,-12.939003,-15.939003,-18.939003,-16.939003,-14.939003,-11.939003,-16.939003,-20.939003,-22.939003,-20.939003,-18.939003,-25.939003,-15.939003,12.060997,1.060997,-11.939003,-16.939003,-14.939003,-11.939003,-24.939003,-30.939003,-26.939003,-28.939003,-28.939003,-26.939003,-29.939003,-34.939003,-42.939003,-47.939003,-51.939003,-54.939003,-60.939003,-67.939,-54.939003,-3
4.939003,-30.939003,-28.939003,-27.939003,-30.939003,-34.939003,-34.939003,-30.939003,-25.939003,-25.939003,-27.939003,-30.939003,-32.939003,-34.939003,-32.939003,-34.939003,-37.939003,-38.939003,-38.939003,-40.939003,-40.939003,-41.939003,-42.939003,-43.939003,-44.939003,-43.939003,-45.939003,-50.939003,-50.939003,-50.939003,-52.939003,-55.939003,-58.939003,-59.939003,-60.939003,-61.939003,-62.939003,-64.939,-64.939,-63.939003,-61.939003,-61.939003,-62.939003,-64.939,-62.939003,-60.939003,-64.939,-66.939,-64.939,-30.939003,-11.939003,-9.939003,-20.939003,-36.939003,-57.939003,-58.939003,-51.939003,-53.939003,-54.939003,-55.939003,-54.939003,-53.939003,-50.939003,-47.939003,-45.939003,-50.939003,-52.939003,-50.939003,-49.939003,-50.939003,-51.939003,-52.939003,-54.939003,-52.939003,-51.939003,-55.939003,-56.939003,-57.939003,-57.939003,-60.939003,-61.939003,-49.939003,-44.939003,-47.939003,-46.939003,-44.939003,-41.939003,-45.939003,-51.939003,-40.939003,-19.939003,12.060997,-5.939003,-20.939003,-17.939003,-11.939003,-5.939003,-7.939003,-7.939003,-7.939003,-17.939003,-26.939003,-26.939003,-15.939003,-2.939003,-10.939003,-15.939003,-19.939003,-20.939003,-21.939003,-26.939003,-29.939003,-31.939003,-32.939003,-34.939003,-36.939003,-35.939003,-34.939003,-35.939003,-35.939003,-36.939003,-35.939003,-34.939003,-32.939003,-29.939003,-26.939003,-24.939003,-28.939003,-35.939003,-36.939003,-35.939003,-31.939003,-28.939003,-24.939003,-22.939003,-22.939003,-23.939003,-22.939003,-23.939003,-24.939003,-31.939003,-35.939003,-30.939003,-16.939003,-1.939003,-11.939003,-17.939003,-16.939003,-19.939003,-21.939003,-18.939003,-15.939003,-12.939003,-8.939003,-7.939003,-8.939003,-7.939003,-9.939003,-21.939003,-28.939003,-31.939003,-4.939003,9.060997,8.060997,9.060997,12.060997,15.060997,11.060997,7.060997,18.060997,10.060997,-18.939003,-56.939003,-83.939,-84.939,-85.939,-86.939,-89.939,-89.939,-89.939,-90.939,-91.939,-91.939,-92.939,-93.939,-93.939,-94.939,-96.939,-75.939,-64.939,-81.939,-
93.939,-102.939,-100.939,-100.939,-101.939,-100.939,-100.939,-98.939,-96.939,-94.939,-94.939,-95.939,-94.939,-93.939,-92.939,-89.939,-89.939,-90.939,-91.939,-64.939,-9.939003,14.060997,20.060997,-11.939003,-51.939003,-88.939,-83.939,-81.939,-83.939,-84.939,-84.939,-83.939,-81.939,-80.939,-78.939,-76.939,-74.939,-73.939,-71.939,-70.939,-76.939,-84.939,-91.939,-96.939,-99.939,-97.939,-94.939,-91.939,-93.939,-97.939,-100.939,-86.939,-52.939003,-38.939003,-32.939003,-42.939003,-46.939003,-47.939003,-46.939003,-47.939003,-48.939003,-48.939003,-48.939003,-49.939003,-48.939003,-47.939003,-52.939003,-53.939003,-50.939003,-58.939003,-65.939,-70.939,-73.939,-74.939,-79.939,-83.939,-89.939,-86.939,-87.939,-93.939,-97.939,-99.939,-95.939,-96.939,-100.939,-100.939,-95.939,-81.939,-73.939,-68.939,-71.939,-72.939,-72.939,-69.939,-64.939,-58.939003,-61.939003,-69.939,-70.939,-70.939,-70.939,-72.939,-74.939,-77.939,-73.939,-64.939,-44.939003,-33.939003,-30.939003,-28.939003,-26.939003,66.061,64.061,63.060997,60.060997,63.060997,70.061,69.061,68.061,63.060997,60.060997,60.060997,64.061,47.060997,9.060997,-44.939003,-83.939,-82.939,-79.939,-74.939,-74.939,-70.939,-64.939,-57.939003,-51.939003,-46.939003,-36.939003,-24.939003,-25.939003,-22.939003,-18.939003,-11.939003,-7.939003,-12.939003,-8.939003,-3.939003,-2.939003,-0.939003,0.06099701,0.06099701,-0.939003,-9.939003,-17.939003,-24.939003,-20.939003,-19.939003,-21.939003,-24.939003,-29.939003,-36.939003,-41.939003,-45.939003,-49.939003,-55.939003,-65.939,-76.939,-83.939,-88.939,-90.939,-89.939,-93.939,-96.939,-97.939,-98.939,-99.939,-101.939,-101.939,-100.939,-99.939,-98.939,-99.939,-98.939,-97.939,-98.939,-96.939,-95.939,-96.939,-97.939,-97.939,-89.939,-52.939003,57.060997,79.061,67.061,84.061,69.061,23.060997,-47.939003,-98.939,-95.939,-93.939,-91.939,-88.939,-84.939,-78.939,-79.939,-78.939,-70.939,-64.939,-59.939003,-55.939003,-49.939003,-39.939003,-33.939003,-27.939003,-25.939003,-19.939003,-11.939003,-7.939003,-4.939003,-4.9390
03,1.060997,5.060997,5.060997,5.060997,4.060997,1.060997,-0.939003,-2.939003,-4.939003,-6.939003,-9.939003,-13.939003,-18.939003,-22.939003,-31.939003,-42.939003,-47.939003,-52.939003,-57.939003,-65.939,-75.939,-79.939,-81.939,-84.939,-88.939,-92.939,-96.939,-100.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-99.939,-31.939003,57.060997,78.061,94.061,105.061,64.061,6.060997,-75.939,-99.939,-99.939,-99.939,-99.939,-98.939,-96.939,-96.939,-96.939,-97.939,-97.939,-96.939,-93.939,-86.939,-83.939,-81.939,-78.939,-76.939,-73.939,-65.939,-59.939003,-53.939003,-46.939003,-38.939003,-31.939003,-27.939003,-23.939003,-16.939003,-10.939003,-4.939003,-0.939003,1.060997,3.060997,3.060997,3.060997,4.060997,6.060997,6.060997,3.060997,0.06099701,-1.939003,-6.939003,-13.939003,-18.939003,-21.939003,-26.939003,-32.939003,-37.939003,-41.939003,-49.939003,-57.939003,-62.939003,-70.939,-84.939,-65.939,-47.939003,-36.939003,-25.939003,-11.939003,24.060997,54.060997,79.061,69.061,61.060997,66.061,67.061,67.061,71.061,69.061,59.060997,77.061,72.061,6.060997,14.060997,49.060997,38.060997,50.060997,87.061,86.061,77.061,59.060997,56.060997,59.060997,57.060997,54.060997,50.060997,44.060997,40.060997,41.060997,37.060997,31.060997,28.060997,21.060997,12.060997,13.060997,8.060997,-11.939003,-14.939003,-10.939003,-12.939003,-20.939003,-34.939003,-33.939003,-22.939003,11.060997,27.060997,35.060997,32.060997,25.060997,16.060997,-24.939003,-42.939003,0.06099701,10.060997,5.060997,2.060997,-1.939003,-4.939003,-9.939003,-14.939003,-17.939003,-20.939003,-22.939003,-28.939003,-36.939003,-47.939003,-49.939003,-51.939003,-62.939003,-68.939,-74.939,-79.939,-83.939,-88.939,-93.939,-97.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-100.939,-60.939003,18.060997,33.060997,37.060997,34.060997,30.060997,22.060997,-14.939003,-54.939003,-96.939,-96.939,-92.939,-91.939,-94.939,-89.939,-16.939003,18.060997,16.060997,19.060997,8.060997,-38.939003,-61.
939003,-73.939,-59.939003,-48.939003,-44.939003,-39.939003,-35.939003,-29.939003,-24.939003,-19.939003,-14.939003,-11.939003,-8.939003,-8.939003,-7.939003,0.06099701,3.060997,5.060997,7.060997,6.060997,3.060997,6.060997,6.060997,-1.939003,-2.939003,-1.939003,-11.939003,-13.939003,-8.939003,-12.939003,-19.939003,-28.939003,-30.939003,-32.939003,-47.939003,-56.939003,-59.939003,-65.939,-72.939,-76.939,-79.939,-81.939,-82.939,-83.939,-85.939,-94.939,-89.939,-47.939003,24.060997,102.061,96.061,68.061,21.060997,-50.939003,-100.939,-99.939,-99.939,-98.939,-99.939,-98.939,-98.939,-98.939,-98.939,-97.939,-97.939,-97.939,-96.939,-93.939,-88.939,-89.939,-88.939,-80.939,-76.939,-72.939,-71.939,-68.939,-65.939,-60.939003,-52.939003,-42.939003,-37.939003,-36.939003,-31.939003,-25.939003,-15.939003,-15.939003,-14.939003,-10.939003,-7.939003,-4.939003,-1.939003,-1.939003,-4.939003,-16.939003,-24.939003,-23.939003,-24.939003,-26.939003,-30.939003,-33.939003,-37.939003,-39.939003,-38.939003,-29.939003,-25.939003,-22.939003,-19.939003,-16.939003,-14.939003,-6.939003,-3.939003,-13.939003,-16.939003,-13.939003,3.060997,22.060997,43.060997,41.060997,37.060997,35.060997,39.060997,44.060997,48.060997,50.060997,53.060997,51.060997,48.060997,48.060997,46.060997,41.060997,33.060997,32.060997,38.060997,44.060997,40.060997,20.060997,-14.939003,-49.939003,-50.939003,-39.939003,-15.939003,-5.939003,3.060997,10.060997,13.060997,15.060997,18.060997,16.060997,9.060997,15.060997,22.060997,26.060997,27.060997,27.060997,23.060997,32.060997,56.060997,46.060997,29.060997,4.060997,-3.939003,-4.939003,9.060997,13.060997,7.060997,6.060997,5.060997,3.060997,1.060997,-0.939003,-7.939003,-14.939003,-19.939003,-29.939003,-41.939003,-60.939003,-60.939003,-49.939003,-29.939003,-19.939003,-18.939003,-22.939003,-21.939003,-7.939003,11.060997,29.060997,26.060997,9.060997,-23.939003,-33.939003,-40.939003,-42.939003,-47.939003,-54.939003,-60.939003,-67.939,-77.939,-77.939,-78.939,-82.939,-85.939,-89.939,-90.939,-92.9
39,-97.939,-98.939,-98.939,-97.939,-96.939,-95.939,-94.939,-92.939,-88.939,-86.939,-83.939,-82.939,-76.939,-69.939,-64.939,-61.939003,-58.939003,-51.939003,-44.939003,-43.939003,-39.939003,-36.939003,-44.939003,-54.939003,-68.939,-53.939003,-37.939003,-23.939003,-17.939003,-13.939003,-18.939003,-20.939003,-23.939003,-29.939003,-33.939003,-32.939003,-32.939003,-34.939003,-48.939003,-57.939003,-61.939003,-64.939,-68.939,-76.939,-82.939,-89.939,-88.939,-89.939,-95.939,-98.939,-101.939,-101.939,-100.939,-96.939,-60.939003,-43.939003,-45.939003,-46.939003,-45.939003,-40.939003,-49.939003,-61.939003,-45.939003,-16.939003,23.060997,-5.939003,-28.939003,-22.939003,-19.939003,-18.939003,-24.939003,-29.939003,-34.939003,-38.939003,-41.939003,-39.939003,-47.939003,-60.939003,-64.939,-64.939,-63.939003,-64.939,-64.939,-64.939,-67.939,-70.939,-70.939,-67.939,-61.939003,-62.939003,-60.939003,-50.939003,-39.939003,-30.939003,-33.939003,-33.939003,-30.939003,-30.939003,-30.939003,-29.939003,-31.939003,-35.939003,-35.939003,-34.939003,-32.939003,-27.939003,-22.939003,-20.939003,-19.939003,-20.939003,-19.939003,-20.939003,-24.939003,-28.939003,-33.939003,-40.939003,-46.939003,-52.939003,-48.939003,-51.939003,-61.939003,-60.939003,-56.939003,-52.939003,-54.939003,-56.939003,-55.939003,-56.939003,-57.939003,-56.939003,-53.939003,-49.939003,-47.939003,-48.939003,-49.939003,-48.939003,-43.939003,-43.939003,-42.939003,-41.939003,-38.939003,-35.939003,-36.939003,-38.939003,-41.939003,-50.939003,-58.939003,-59.939003,-62.939003,-64.939,-69.939,-71.939,-69.939,-71.939,-73.939,-73.939,-75.939,-77.939,-78.939,-80.939,-84.939,-66.939,-59.939003,-75.939,-88.939,-97.939,-93.939,-93.939,-97.939,-94.939,-93.939,-96.939,-97.939,-97.939,-98.939,-99.939,-98.939,-99.939,-99.939,-99.939,-98.939,-97.939,-98.939,-64.939,4.060997,37.060997,46.060997,1.060997,-50.939003,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-101.939,-100.939,-98.939,-97.939,-100.939,-99.939,-
97.939,-89.939,-88.939,-95.939,-94.939,-92.939,-94.939,-96.939,-99.939,-101.939,-85.939,-51.939003,-42.939003,-40.939003,-46.939003,-47.939003,-45.939003,-45.939003,-45.939003,-44.939003,-42.939003,-41.939003,-46.939003,-47.939003,-48.939003,-53.939003,-52.939003,-47.939003,-51.939003,-56.939003,-63.939003,-64.939,-62.939003,-63.939003,-69.939,-81.939,-77.939,-76.939,-84.939,-90.939,-95.939,-91.939,-92.939,-97.939,-100.939,-94.939,-71.939,-59.939003,-51.939003,-58.939003,-59.939003,-54.939003,-48.939003,-45.939003,-53.939003,-52.939003,-48.939003,-47.939003,-45.939003,-44.939003,-46.939003,-48.939003,-51.939003,-49.939003,-46.939003,-46.939003,-48.939003,-53.939003,-52.939003,-52.939003,41.060997,39.060997,34.060997,24.060997,19.060997,19.060997,10.060997,3.060997,-1.939003,-6.939003,-10.939003,-15.939003,-21.939003,-30.939003,-31.939003,-30.939003,-26.939003,-24.939003,-22.939003,-21.939003,-22.939003,-24.939003,-20.939003,-18.939003,-19.939003,-17.939003,-15.939003,-20.939003,-23.939003,-25.939003,-28.939003,-32.939003,-38.939003,-43.939003,-48.939003,-47.939003,-49.939003,-53.939003,-58.939003,-64.939,-70.939,-52.939003,-27.939003,-32.939003,-48.939003,-73.939,-77.939,-78.939,-81.939,-82.939,-84.939,-85.939,-87.939,-90.939,-94.939,-96.939,-98.939,-97.939,-95.939,-98.939,-99.939,-100.939,-98.939,-96.939,-98.939,-97.939,-96.939,-99.939,-99.939,-99.939,-96.939,-94.939,-95.939,-91.939,-84.939,-81.939,-81.939,-84.939,-73.939,-48.939003,9.060997,21.060997,13.060997,13.060997,3.060997,-16.939003,-36.939003,-46.939003,-36.939003,-33.939003,-33.939003,-27.939003,-23.939003,-21.939003,-20.939003,-19.939003,-19.939003,-16.939003,-12.939003,-15.939003,-18.939003,-20.939003,-20.939003,-21.939003,-24.939003,-28.939003,-32.939003,-36.939003,-39.939003,-39.939003,-42.939003,-47.939003,-52.939003,-56.939003,-59.939003,-62.939003,-65.939,-69.939,-70.939,-71.939,-72.939,-73.939,-75.939,-76.939,-79.939,-83.939,-84.939,-85.939,-88.939,-90.939,-94.939,-94.939,-95.939,-95.939,-97.939,-
98.939,-99.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-42.939003,38.060997,63.060997,75.061,74.061,40.060997,-2.939003,-59.939003,-70.939,-62.939003,-60.939003,-56.939003,-47.939003,-45.939003,-42.939003,-39.939003,-36.939003,-34.939003,-30.939003,-25.939003,-17.939003,-16.939003,-15.939003,-13.939003,-15.939003,-17.939003,-14.939003,-14.939003,-17.939003,-19.939003,-20.939003,-20.939003,-22.939003,-26.939003,-29.939003,-33.939003,-39.939003,-43.939003,-47.939003,-45.939003,-49.939003,-53.939003,-55.939003,-58.939003,-61.939003,-65.939,-68.939,-69.939,-71.939,-73.939,-75.939,-76.939,-77.939,-79.939,-81.939,-82.939,-85.939,-88.939,-89.939,-91.939,-97.939,-76.939,-55.939003,-41.939003,-30.939003,-16.939003,19.060997,49.060997,74.061,64.061,54.060997,55.060997,56.060997,57.060997,55.060997,54.060997,52.060997,56.060997,43.060997,-6.939003,-4.939003,17.060997,11.060997,14.060997,25.060997,20.060997,13.060997,5.060997,1.060997,-2.939003,-6.939003,-10.939003,-14.939003,-15.939003,-16.939003,-19.939003,-19.939003,-17.939003,-16.939003,-19.939003,-25.939003,-20.939003,-16.939003,-19.939003,-19.939003,-17.939003,-15.939003,-13.939003,-9.939003,-10.939003,-7.939003,5.060997,20.060997,34.060997,30.060997,26.060997,24.060997,-23.939003,-60.939003,-61.939003,-64.939,-67.939,-68.939,-69.939,-70.939,-72.939,-73.939,-74.939,-75.939,-76.939,-78.939,-81.939,-84.939,-85.939,-86.939,-89.939,-91.939,-93.939,-95.939,-96.939,-98.939,-100.939,-101.939,-102.939,-102.939,-100.939,-99.939,-99.939,-101.939,-97.939,-93.939,-92.939,-92.939,-92.939,-81.939,-47.939003,11.060997,30.060997,39.060997,33.060997,27.060997,19.060997,-14.939003,-43.939003,-67.939,-45.939003,-27.939003,-30.939003,-31.939003,-30.939003,-26.939003,-26.939003,-32.939003,-32.939003,-30.939003,-27.939003,-22.939003,-19.939003,-16.939003,-13.939003,-14.939003,-14.939003,-15.939003,-18.939003,-23.939003,-28.939003,-29.939003,-31.939003,-33.939003,-41.939003,-47.939003,-48.939003,-48.939003,-49.93
9003,-51.939003,-55.939003,-59.939003,-60.939003,-61.939003,-68.939,-69.939,-69.939,-72.939,-65.939,-48.939003,-55.939003,-65.939,-75.939,-73.939,-67.939,-80.939,-87.939,-88.939,-90.939,-93.939,-94.939,-95.939,-96.939,-95.939,-94.939,-92.939,-98.939,-91.939,-51.939003,9.060997,71.061,55.060997,30.060997,-3.939003,-46.939003,-74.939,-64.939,-61.939003,-61.939003,-55.939003,-51.939003,-48.939003,-46.939003,-43.939003,-38.939003,-35.939003,-33.939003,-33.939003,-31.939003,-26.939003,-26.939003,-27.939003,-25.939003,-24.939003,-22.939003,-22.939003,-23.939003,-25.939003,-28.939003,-29.939003,-29.939003,-30.939003,-33.939003,-35.939003,-36.939003,-35.939003,-36.939003,-39.939003,-45.939003,-49.939003,-52.939003,-52.939003,-52.939003,-53.939003,-43.939003,-35.939003,-36.939003,-44.939003,-51.939003,-32.939003,-8.939003,18.060997,17.060997,17.060997,26.060997,27.060997,26.060997,31.060997,34.060997,34.060997,43.060997,40.060997,8.060997,-3.939003,-4.939003,17.060997,35.060997,51.060997,46.060997,38.060997,29.060997,28.060997,31.060997,27.060997,23.060997,18.060997,16.060997,13.060997,7.060997,1.060997,-5.939003,-9.939003,-14.939003,-17.939003,-14.939003,-16.939003,-31.939003,-44.939003,-52.939003,-43.939003,-32.939003,-20.939003,-12.939003,-7.939003,-6.939003,-2.939003,3.060997,1.060997,0.06099701,0.06099701,9.060997,14.060997,9.060997,14.060997,22.060997,22.060997,25.060997,34.060997,25.060997,16.060997,9.060997,7.060997,6.060997,7.060997,8.060997,10.060997,7.060997,4.060997,2.060997,0.06099701,-3.939003,-12.939003,-20.939003,-26.939003,-35.939003,-44.939003,-53.939003,-56.939003,-56.939003,-57.939003,-60.939003,-63.939003,-71.939,-64.939,-20.939003,16.060997,44.060997,29.060997,-5.939003,-62.939003,-76.939,-80.939,-75.939,-72.939,-71.939,-68.939,-68.939,-72.939,-69.939,-66.939,-61.939003,-58.939003,-57.939003,-55.939003,-53.939003,-50.939003,-46.939003,-43.939003,-42.939003,-41.939003,-40.939003,-37.939003,-36.939003,-36.939003,-31.939003,-28.939003,-29.939003,-33.939003
,-38.939003,-39.939003,-40.939003,-39.939003,-40.939003,-41.939003,-43.939003,-49.939003,-54.939003,-31.939003,-15.939003,-5.939003,-14.939003,-31.939003,-62.939003,-72.939,-73.939,-75.939,-75.939,-76.939,-78.939,-80.939,-79.939,-79.939,-80.939,-85.939,-88.939,-89.939,-90.939,-91.939,-94.939,-95.939,-95.939,-96.939,-93.939,-86.939,-83.939,-81.939,-79.939,-75.939,-68.939,-47.939003,-39.939003,-45.939003,-47.939003,-47.939003,-42.939003,-42.939003,-46.939003,-44.939003,-41.939003,-37.939003,-40.939003,-43.939003,-46.939003,-50.939003,-53.939003,-49.939003,-48.939003,-51.939003,-50.939003,-48.939003,-47.939003,-47.939003,-48.939003,-41.939003,-37.939003,-37.939003,-33.939003,-29.939003,-25.939003,-22.939003,-21.939003,-15.939003,-12.939003,-14.939003,-7.939003,-3.939003,-5.939003,-27.939003,-53.939003,-51.939003,-49.939003,-46.939003,-47.939003,-48.939003,-46.939003,-43.939003,-40.939003,-37.939003,-33.939003,-29.939003,-24.939003,-20.939003,-19.939003,-18.939003,-19.939003,-21.939003,-21.939003,-19.939003,-31.939003,-34.939003,-15.939003,0.06099701,14.060997,2.060997,-3.939003,-3.939003,-7.939003,-9.939003,-11.939003,-9.939003,-6.939003,-9.939003,-11.939003,-12.939003,-12.939003,-16.939003,-30.939003,-32.939003,-30.939003,-19.939003,-14.939003,-14.939003,-18.939003,-21.939003,-21.939003,-20.939003,-19.939003,-19.939003,-25.939003,-38.939003,-49.939003,-55.939003,-54.939003,-54.939003,-54.939003,-56.939003,-56.939003,-54.939003,-55.939003,-55.939003,-52.939003,-52.939003,-52.939003,-52.939003,-54.939003,-58.939003,-51.939003,-48.939003,-50.939003,-56.939003,-62.939003,-58.939003,-56.939003,-56.939003,-57.939003,-57.939003,-56.939003,-56.939003,-56.939003,-55.939003,-57.939003,-60.939003,-59.939003,-58.939003,-60.939003,-60.939003,-60.939003,-60.939003,-49.939003,-28.939003,-19.939003,-17.939003,-33.939003,-49.939003,-62.939003,-62.939003,-64.939,-71.939,-69.939,-69.939,-72.939,-73.939,-73.939,-75.939,-76.939,-75.939,-74.939,-74.939,-78.939,-84.939,-88.939,-88.939,-92.9
39,-100.939,-97.939,-94.939,-95.939,-97.939,-100.939,-102.939,-90.939,-65.939,-63.939003,-61.939003,-49.939003,-42.939003,-37.939003,-42.939003,-42.939003,-37.939003,-38.939003,-40.939003,-42.939003,-47.939003,-53.939003,-55.939003,-52.939003,-46.939003,-48.939003,-50.939003,-52.939003,-54.939003,-57.939003,-58.939003,-64.939,-74.939,-72.939,-69.939,-67.939,-71.939,-76.939,-76.939,-77.939,-79.939,-80.939,-81.939,-85.939,-86.939,-85.939,-84.939,-84.939,-86.939,-78.939,-70.939,-66.939,-71.939,-80.939,-80.939,-78.939,-77.939,-75.939,-76.939,-81.939,-72.939,-58.939003,-40.939003,-31.939003,-31.939003,-31.939003,-32.939003,1.060997,1.060997,0.06099701,-8.939003,-13.939003,-14.939003,-20.939003,-25.939003,-25.939003,-28.939003,-33.939003,-36.939003,-35.939003,-30.939003,-26.939003,-22.939003,-20.939003,-20.939003,-19.939003,-22.939003,-27.939003,-31.939003,-30.939003,-29.939003,-34.939003,-36.939003,-39.939003,-44.939003,-47.939003,-51.939003,-57.939003,-63.939003,-67.939,-74.939,-81.939,-80.939,-81.939,-86.939,-92.939,-98.939,-101.939,-73.939,-35.939003,-37.939003,-58.939003,-99.939,-103.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-96.939,-91.939,-90.939,-89.939,-88.939,-85.939,-81.939,-81.939,-78.939,-75.939,-75.939,-72.939,-69.939,-67.939,-66.939,-64.939,-59.939003,-53.939003,-47.939003,-46.939003,-48.939003,-40.939003,-30.939003,-14.939003,-11.939003,-14.939003,-17.939003,-22.939003,-28.939003,-29.939003,-26.939003,-16.939003,-15.939003,-19.939003,-17.939003,-17.939003,-18.939003,-17.939003,-18.939003,-23.939003,-24.939003,-24.939003,-29.939003,-34.939003,-39.939003,-42.939003,-45.939003,-48.939003,-54.939003,-61.939003,-67.939,-71.939,-71.939,-76.939,-82.939,-87.939,-91.939,-94.939,-96.939,-99.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-98.939,-96.939,-95.939,-93.939,-92.939,-90.939,-88.939,-86.939,-83.939,-82.939,-81.939,-79.939,-76.939,-73.939,-72.939,-70.9
39,-67.939,-34.939003,8.060997,22.060997,25.060997,20.060997,2.060997,-16.939003,-37.939003,-37.939003,-30.939003,-30.939003,-26.939003,-19.939003,-19.939003,-18.939003,-17.939003,-15.939003,-15.939003,-14.939003,-13.939003,-10.939003,-12.939003,-14.939003,-14.939003,-17.939003,-22.939003,-23.939003,-26.939003,-31.939003,-36.939003,-40.939003,-41.939003,-46.939003,-51.939003,-56.939003,-62.939003,-71.939,-77.939,-81.939,-80.939,-84.939,-88.939,-90.939,-94.939,-97.939,-101.939,-103.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-99.939,-97.939,-95.939,-95.939,-95.939,-92.939,-90.939,-90.939,-70.939,-51.939003,-42.939003,-31.939003,-17.939003,6.060997,26.060997,41.060997,33.060997,24.060997,21.060997,20.060997,20.060997,17.060997,16.060997,17.060997,14.060997,6.060997,-16.939003,-15.939003,-6.939003,-8.939003,-8.939003,-7.939003,-10.939003,-13.939003,-13.939003,-16.939003,-19.939003,-21.939003,-22.939003,-23.939003,-20.939003,-19.939003,-20.939003,-16.939003,-11.939003,-8.939003,-7.939003,-10.939003,-7.939003,-5.939003,-9.939003,-1.939003,10.060997,10.060997,16.060997,29.060997,22.060997,12.060997,2.060997,16.060997,36.060997,30.060997,27.060997,27.060997,-22.939003,-68.939,-93.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-101.939,-99.939,-96.939,-95.939,-93.939,-92.939,-90.939,-88.939,-85.939,-83.939,-82.939,-80.939,-77.939,-76.939,-74.939,-71.939,-66.939,-64.939,-65.939,-61.939003,-59.939003,-58.939003,-56.939003,-54.939003,-49.939003,-30.939003,4.060997,26.060997,40.060997,35.060997,29.060997,19.060997,-13.939003,-38.939003,-56.939003,-27.939003,-6.939003,-12.939003,-14.939003,-16.939003,-25.939003,-29.939003,-29.939003,-25.939003,-21.939003,-22.939003,-24.939003,-27.939003,-28.939003,-29.939003,-31.939003,-33.939003,-36.939003,-41.939003,-48.939003,-55.939003,-57.939003,-60.939003,-64.939,-72.939,-78.939,-82.939,-83.939,-84.939,-88.939,-91.939,-95.939,-96.939,-97.939,-102.939,-103.939,-103.939,-103.939,-9
1.939,-69.939,-72.939,-80.939,-91.939,-85.939,-72.939,-81.939,-85.939,-84.939,-80.939,-78.939,-77.939,-75.939,-74.939,-72.939,-69.939,-66.939,-66.939,-60.939003,-40.939003,-11.939003,17.060997,5.060997,-6.939003,-20.939003,-35.939003,-43.939003,-35.939003,-34.939003,-35.939003,-30.939003,-27.939003,-26.939003,-26.939003,-25.939003,-22.939003,-20.939003,-19.939003,-22.939003,-22.939003,-20.939003,-20.939003,-21.939003,-22.939003,-23.939003,-22.939003,-23.939003,-24.939003,-27.939003,-30.939003,-33.939003,-36.939003,-38.939003,-39.939003,-39.939003,-39.939003,-39.939003,-37.939003,-37.939003,-43.939003,-48.939003,-53.939003,-53.939003,-52.939003,-50.939003,-33.939003,-20.939003,-23.939003,-31.939003,-37.939003,-20.939003,3.060997,31.060997,34.060997,34.060997,39.060997,36.060997,32.060997,35.060997,36.060997,35.060997,39.060997,33.060997,4.060997,-7.939003,-10.939003,5.060997,16.060997,23.060997,16.060997,8.060997,-0.939003,-1.939003,0.06099701,-5.939003,-11.939003,-15.939003,-16.939003,-18.939003,-25.939003,-30.939003,-34.939003,-31.939003,-35.939003,-43.939003,-45.939003,-49.939003,-59.939003,-57.939003,-49.939003,-34.939003,-25.939003,-21.939003,-16.939003,-12.939003,-11.939003,-10.939003,-9.939003,-10.939003,-11.939003,-12.939003,-3.939003,3.060997,5.060997,7.060997,7.060997,8.060997,10.060997,15.060997,10.060997,7.060997,9.060997,10.060997,9.060997,3.060997,-0.939003,-0.939003,-1.939003,-2.939003,-5.939003,-12.939003,-21.939003,-29.939003,-33.939003,-31.939003,-39.939003,-47.939003,-57.939003,-59.939003,-61.939003,-69.939,-78.939,-86.939,-86.939,-72.939,-31.939003,-3.939003,14.060997,1.060997,-24.939003,-63.939003,-70.939,-70.939,-64.939,-59.939003,-54.939003,-51.939003,-51.939003,-52.939003,-50.939003,-47.939003,-43.939003,-42.939003,-41.939003,-39.939003,-37.939003,-33.939003,-31.939003,-30.939003,-31.939003,-31.939003,-31.939003,-30.939003,-31.939003,-35.939003,-33.939003,-33.939003,-34.939003,-40.939003,-48.939003,-50.939003,-52.939003,-53.939003,-56.939003,-
59.939003,-62.939003,-69.939,-73.939,-18.939003,15.060997,29.060997,4.060997,-32.939003,-84.939,-98.939,-97.939,-95.939,-94.939,-92.939,-90.939,-89.939,-87.939,-85.939,-84.939,-83.939,-82.939,-81.939,-80.939,-78.939,-77.939,-75.939,-72.939,-73.939,-68.939,-58.939003,-57.939003,-55.939003,-50.939003,-48.939003,-45.939003,-41.939003,-41.939003,-46.939003,-47.939003,-46.939003,-43.939003,-42.939003,-43.939003,-42.939003,-43.939003,-46.939003,-44.939003,-43.939003,-44.939003,-45.939003,-45.939003,-40.939003,-37.939003,-36.939003,-41.939003,-42.939003,-37.939003,-28.939003,-18.939003,-12.939003,-10.939003,-11.939003,-10.939003,-7.939003,-3.939003,0.06099701,3.060997,6.060997,5.060997,-1.939003,5.060997,8.060997,0.06099701,-28.939003,-60.939003,-54.939003,-51.939003,-49.939003,-50.939003,-49.939003,-46.939003,-44.939003,-42.939003,-39.939003,-34.939003,-28.939003,-22.939003,-19.939003,-20.939003,-19.939003,-18.939003,-20.939003,-20.939003,-18.939003,-32.939003,-33.939003,-3.939003,22.060997,41.060997,23.060997,16.060997,21.060997,16.060997,12.060997,10.060997,15.060997,23.060997,20.060997,17.060997,14.060997,17.060997,10.060997,-17.939003,-21.939003,-13.939003,8.060997,17.060997,13.060997,11.060997,8.060997,6.060997,5.060997,5.060997,9.060997,-4.939003,-38.939003,-58.939003,-69.939,-68.939,-67.939,-66.939,-67.939,-66.939,-64.939,-64.939,-64.939,-60.939003,-60.939003,-59.939003,-59.939003,-60.939003,-63.939003,-56.939003,-51.939003,-55.939003,-58.939003,-60.939003,-57.939003,-55.939003,-55.939003,-55.939003,-54.939003,-53.939003,-53.939003,-54.939003,-52.939003,-53.939003,-56.939003,-54.939003,-52.939003,-53.939003,-55.939003,-55.939003,-52.939003,-45.939003,-35.939003,-32.939003,-34.939003,-45.939003,-50.939003,-53.939003,-51.939003,-52.939003,-57.939003,-55.939003,-55.939003,-58.939003,-59.939003,-59.939003,-59.939003,-60.939003,-60.939003,-59.939003,-57.939003,-60.939003,-71.939,-83.939,-89.939,-94.939,-100.939,-98.939,-97.939,-98.939,-99.939,-101.939,-102.939,-93.939,-
71.939,-71.939,-68.939,-51.939003,-42.939003,-38.939003,-42.939003,-42.939003,-35.939003,-38.939003,-42.939003,-42.939003,-47.939003,-53.939003,-52.939003,-50.939003,-47.939003,-48.939003,-49.939003,-49.939003,-50.939003,-52.939003,-54.939003,-58.939003,-61.939003,-62.939003,-62.939003,-57.939003,-57.939003,-61.939003,-61.939003,-61.939003,-61.939003,-65.939,-72.939,-88.939,-92.939,-89.939,-88.939,-88.939,-90.939,-83.939,-74.939,-67.939,-75.939,-88.939,-93.939,-95.939,-95.939,-93.939,-93.939,-98.939,-85.939,-64.939,-36.939003,-20.939003,-17.939003,-16.939003,-17.939003,-54.939003,-48.939003,-40.939003,-38.939003,-35.939003,-30.939003,-25.939003,-20.939003,-10.939003,-7.939003,-7.939003,2.060997,7.060997,8.060997,-30.939003,-61.939003,-65.939,-65.939,-65.939,-77.939,-84.939,-86.939,-85.939,-86.939,-91.939,-93.939,-95.939,-95.939,-94.939,-96.939,-98.939,-99.939,-100.939,-100.939,-101.939,-100.939,-98.939,-97.939,-100.939,-103.939,-102.939,-80.939,-50.939003,-34.939003,-50.939003,-98.939,-102.939,-101.939,-102.939,-102.939,-103.939,-101.939,-99.939,-98.939,-96.939,-95.939,-94.939,-87.939,-78.939,-69.939,-64.939,-63.939003,-59.939003,-54.939003,-50.939003,-44.939003,-38.939003,-27.939003,-17.939003,-10.939003,-12.939003,-12.939003,-3.939003,-2.939003,-2.939003,4.060997,8.060997,9.060997,8.060997,2.060997,-14.939003,-18.939003,-16.939003,-10.939003,-8.939003,-13.939003,-27.939003,-38.939003,-35.939003,-40.939003,-48.939003,-58.939003,-65.939,-70.939,-72.939,-76.939,-82.939,-89.939,-96.939,-96.939,-97.939,-97.939,-97.939,-98.939,-98.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-101.939,-101.939,-101.939,-100.939,-101.939,-102.939,-103.939,-103.939,-103.939,-102.939,-101.939,-101.939,-99.939,-95.939,-88.939,-82.939,-79.939,-76.939,-72.939,-67.939,-63.939003,-57.939003,-48.939003,-42.939003,-37.939003,-31.939003,-23.939003,-13.939003,-10.939003,-5.939003,4.060997,-8.939003,-30.939003,-44.939003,-53.939003,-54.939003,-48.939003,-35.93900
3,-9.939003,-1.939003,-1.939003,-7.939003,-11.939003,-13.939003,-18.939003,-24.939003,-30.939003,-35.939003,-41.939003,-49.939003,-58.939003,-65.939,-71.939,-77.939,-80.939,-84.939,-88.939,-93.939,-95.939,-96.939,-97.939,-97.939,-96.939,-97.939,-98.939,-99.939,-99.939,-100.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-98.939,-92.939,-85.939,-79.939,-79.939,-80.939,-72.939,-66.939,-63.939003,-47.939003,-36.939003,-37.939003,-28.939003,-15.939003,-15.939003,-16.939003,-19.939003,-23.939003,-30.939003,-37.939003,-42.939003,-44.939003,-41.939003,-42.939003,-45.939003,-46.939003,-40.939003,-23.939003,-20.939003,-22.939003,-22.939003,-17.939003,-10.939003,-5.939003,-2.939003,1.060997,4.060997,8.060997,13.060997,18.060997,25.060997,28.060997,32.060997,39.060997,44.060997,48.060997,52.060997,55.060997,57.060997,51.060997,39.060997,19.060997,40.060997,74.061,65.061,69.061,84.061,65.061,38.060997,3.060997,13.060997,39.060997,32.060997,27.060997,26.060997,-21.939003,-67.939,-95.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-99.939,-96.939,-96.939,-92.939,-83.939,-78.939,-73.939,-71.939,-65.939,-57.939003,-49.939003,-43.939003,-40.939003,-33.939003,-25.939003,-24.939003,-19.939003,-15.939003,-4.939003,0.06099701,3.060997,3.060997,1.060997,-0.939003,5.060997,11.060997,-3.939003,-7.939003,-2.939003,21.060997,39.060997,40.060997,34.060997,23.060997,-12.939003,-41.939003,-63.939003,-42.939003,-27.939003,-35.939003,-43.939003,-49.939003,-12.939003,11.060997,23.060997,40.060997,35.060997,-23.939003,-65.939,-96.939,-96.939,-95.939,-94.939,-96.939,-97.939,-97.939,-98.939,-99.939,-99.939,-99.939,-99.939,-100.939,-101.939,-101.939,-101.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-92.939,-72.939,-64.939,-64.939,-76.939,-65.939,-45.939003,-49.939003,-50.939003,-45.939003,-35.939003,-28.939003,-26.939003,-20.939003,-15.9
39003,-12.939003,-9.939003,-6.939003,0.06099701,2.060997,-15.939003,-37.939003,-59.939003,-52.939003,-42.939003,-31.939003,-16.939003,-7.939003,-13.939003,-18.939003,-20.939003,-25.939003,-28.939003,-32.939003,-38.939003,-43.939003,-47.939003,-51.939003,-54.939003,-61.939003,-67.939,-71.939,-71.939,-71.939,-71.939,-72.939,-73.939,-74.939,-73.939,-71.939,-67.939,-65.939,-64.939,-60.939003,-53.939003,-43.939003,-34.939003,-26.939003,-16.939003,-8.939003,-4.939003,-4.939003,-7.939003,-3.939003,0.06099701,3.060997,13.060997,20.060997,15.060997,14.060997,15.060997,6.060997,2.060997,1.060997,9.060997,13.060997,9.060997,2.060997,-3.939003,-8.939003,-10.939003,-11.939003,-18.939003,-23.939003,-25.939003,-27.939003,-29.939003,-31.939003,-36.939003,-41.939003,-48.939003,-54.939003,-55.939003,-51.939003,-46.939003,-51.939003,-53.939003,-50.939003,-48.939003,-47.939003,-50.939003,-49.939003,-45.939003,-32.939003,-30.939003,-39.939003,-50.939003,-57.939003,-62.939003,-53.939003,-40.939003,-23.939003,-16.939003,-18.939003,-15.939003,-10.939003,-3.939003,-10.939003,-22.939003,-15.939003,-18.939003,-30.939003,-24.939003,-10.939003,16.060997,6.060997,-16.939003,-18.939003,-14.939003,-1.939003,0.06099701,2.060997,4.060997,5.060997,6.060997,-2.939003,-14.939003,-25.939003,-19.939003,-15.939003,-20.939003,-36.939003,-55.939003,-60.939003,-52.939003,-32.939003,-40.939003,-52.939003,-71.939,-70.939,-63.939003,-65.939,-73.939,-86.939,-65.939,-47.939003,-40.939003,-48.939003,-61.939003,-59.939003,-47.939003,-25.939003,-15.939003,-10.939003,-11.939003,-8.939003,-4.939003,-8.939003,-14.939003,-18.939003,-20.939003,-22.939003,-28.939003,-35.939003,-41.939003,-42.939003,-43.939003,-45.939003,-52.939003,-59.939003,-66.939,-68.939,-69.939,-72.939,-77.939,-85.939,-92.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-92.939,-4.939003,38.060997,37.060997,1.060997,-39.939003,-89.939,-95.939,-84.939,-80.939,-75.939,-71.939,-65.939,-60.939003,-54.939003,-50.
939003,-47.939003,-42.939003,-39.939003,-38.939003,-34.939003,-29.939003,-24.939003,-22.939003,-21.939003,-18.939003,-14.939003,-11.939003,-18.939003,-22.939003,-15.939003,-19.939003,-28.939003,-43.939003,-49.939003,-47.939003,-44.939003,-42.939003,-44.939003,-49.939003,-53.939003,-38.939003,-22.939003,-3.939003,-17.939003,-26.939003,-16.939003,-6.939003,3.060997,3.060997,4.060997,8.060997,-11.939003,-23.939003,-10.939003,9.060997,29.060997,22.060997,15.060997,13.060997,6.060997,1.060997,0.06099701,2.060997,4.060997,-5.939003,-13.939003,-22.939003,-21.939003,-24.939003,-32.939003,-42.939003,-49.939003,-42.939003,-39.939003,-40.939003,-37.939003,-34.939003,-30.939003,-33.939003,-39.939003,-40.939003,-36.939003,-29.939003,-23.939003,-20.939003,-25.939003,-22.939003,-17.939003,-16.939003,-18.939003,-21.939003,-32.939003,-32.939003,-2.939003,17.060997,29.060997,12.060997,7.060997,12.060997,10.060997,7.060997,12.060997,21.060997,33.060997,33.060997,30.060997,23.060997,32.060997,26.060997,-11.939003,-12.939003,2.060997,33.060997,46.060997,42.060997,46.060997,48.060997,43.060997,39.060997,38.060997,52.060997,25.060997,-43.939003,-78.939,-100.939,-100.939,-100.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-79.939,-69.939,-90.939,-93.939,-91.939,-90.939,-90.939,-95.939,-90.939,-84.939,-86.939,-88.939,-93.939,-89.939,-88.939,-87.939,-84.939,-81.939,-80.939,-82.939,-84.939,-75.939,-52.939003,-17.939003,-3.939003,-4.939003,-34.939003,-55.939003,-72.939,-66.939,-62.939003,-59.939003,-58.939003,-58.939003,-58.939003,-58.939003,-57.939003,-52.939003,-51.939003,-55.939003,-52.939003,-48.939003,-45.939003,-61.939003,-82.939,-91.939,-95.939,-95.939,-99.939,-102.939,-101.939,-101.939,-102.939,-103.939,-92.939,-70.939,-66.939,-62.939003,-52.939003,-48.939003,-47.939003,-47.939003,-45.939003,-40.939003,-43.939003,-46.939003,-47.939003,-48.939003,-47.939003,-44.939003,-45.939003,-52.939003,-53.939003,-54.939003,-55.939003,-52.93900
3,-47.939003,-53.939003,-52.939003,-44.939003,-49.939003,-54.939003,-52.939003,-50.939003,-48.939003,-48.939003,-46.939003,-43.939003,-54.939003,-67.939,-81.939,-76.939,-64.939,-71.939,-71.939,-66.939,-63.939003,-59.939003,-57.939003,-64.939,-73.939,-86.939,-94.939,-99.939,-99.939,-100.939,-102.939,-87.939,-64.939,-33.939003,-15.939003,-10.939003,-7.939003,-6.939003,10.060997,15.060997,23.060997,26.060997,27.060997,28.060997,30.060997,31.060997,34.060997,35.060997,34.060997,43.060997,43.060997,36.060997,-26.939003,-77.939,-85.939,-87.939,-86.939,-94.939,-97.939,-98.939,-98.939,-98.939,-97.939,-97.939,-99.939,-98.939,-97.939,-99.939,-99.939,-99.939,-98.939,-96.939,-93.939,-92.939,-90.939,-86.939,-84.939,-83.939,-82.939,-67.939,-45.939003,-31.939003,-35.939003,-59.939003,-60.939003,-58.939003,-56.939003,-54.939003,-51.939003,-46.939003,-42.939003,-38.939003,-36.939003,-35.939003,-34.939003,-32.939003,-29.939003,-23.939003,-21.939003,-26.939003,-28.939003,-29.939003,-28.939003,-26.939003,-24.939003,-25.939003,-26.939003,-28.939003,-32.939003,-36.939003,-36.939003,-40.939003,-44.939003,-45.939003,-44.939003,-42.939003,-48.939003,-40.939003,4.060997,28.060997,44.060997,45.060997,42.060997,34.060997,-27.939003,-75.939,-74.939,-77.939,-80.939,-85.939,-89.939,-91.939,-92.939,-94.939,-96.939,-100.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-100.939,-96.939,-94.939,-93.939,-89.939,-84.939,-79.939,-77.939,-74.939,-71.939,-67.939,-63.939003,-57.939003,-53.939003,-49.939003,-45.939003,-42.939003,-39.939003,-34.939003,-29.939003,-25.939003,-24.939003,-25.939003,-22.939003,-18.939003,-15.939003,-17.939003,-20.939003,-20.939003,-19.939003,-19.939003,-20.939003,-20.939003,-18.939003,-20.939003,-24.939003,-32.939003,-25.939003,-11.939003,0.06099701,7.060997,10.060997,12.060997,-0.939003,-43.939003,-57.939003,-57.939003,-60.939003,-62.939003,-63.939003,-66.939,-69.939,-71.939,-74.939,-77.939,-81.939,-85.939,-89.939,-92.939,-94.939,-96.
939,-98.939,-99.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-98.939,-97.939,-96.939,-93.939,-90.939,-86.939,-83.939,-81.939,-75.939,-69.939,-64.939,-59.939003,-55.939003,-53.939003,-51.939003,-48.939003,-45.939003,-40.939003,-36.939003,-34.939003,-31.939003,-25.939003,-23.939003,-23.939003,-23.939003,-22.939003,-22.939003,-20.939003,-18.939003,-19.939003,-17.939003,-15.939003,-14.939003,-11.939003,-8.939003,-5.939003,-4.939003,-2.939003,-1.939003,0.06099701,2.060997,5.060997,9.060997,14.060997,12.060997,-6.939003,-4.939003,4.060997,6.060997,19.060997,43.060997,47.060997,47.060997,43.060997,44.060997,48.060997,50.060997,52.060997,55.060997,56.060997,56.060997,59.060997,61.060997,61.060997,65.061,66.061,65.061,65.061,56.060997,21.060997,37.060997,70.061,60.060997,58.060997,64.061,49.060997,29.060997,1.060997,12.060997,36.060997,30.060997,26.060997,26.060997,-16.939003,-52.939003,-62.939003,-61.939003,-56.939003,-53.939003,-51.939003,-50.939003,-45.939003,-41.939003,-41.939003,-39.939003,-36.939003,-32.939003,-29.939003,-26.939003,-25.939003,-24.939003,-24.939003,-21.939003,-19.939003,-19.939003,-18.939003,-16.939003,-17.939003,-19.939003,-23.939003,-25.939003,-26.939003,-24.939003,-24.939003,-26.939003,-30.939003,-35.939003,-42.939003,-43.939003,-42.939003,-52.939003,-38.939003,-0.939003,21.060997,37.060997,43.060997,36.060997,22.060997,-7.939003,-42.939003,-83.939,-77.939,-71.939,-74.939,-78.939,-75.939,-1.939003,44.060997,62.060997,70.061,51.060997,-25.939003,-72.939,-103.939,-101.939,-100.939,-99.939,-95.939,-91.939,-90.939,-87.939,-83.939,-83.939,-82.939,-78.939,-73.939,-68.939,-63.939003,-60.939003,-58.939003,-55.939003,-52.939003,-50.939003,-50.939003,-48.939003,-44.939003,-43.939003,-43.939003,-40.939003,-37.939003,-33.939003,-33.939003,-34.939003,-33.939003,-31.939003,-27.939003,-32.939003,-35.939003,-35.939003,-33.939003,-31.939003,-29.939003,-31.939003,-33.939003,-31.939003,-31.939003,-33.939003,-37.939003,-38.939003,-33.93900
3,-15.939003,5.060997,0.06099701,-4.939003,-11.939003,-30.939003,-43.939003,-40.939003,-40.939003,-39.939003,-37.939003,-35.939003,-31.939003,-29.939003,-27.939003,-26.939003,-26.939003,-26.939003,-25.939003,-23.939003,-20.939003,-15.939003,-10.939003,-7.939003,-7.939003,-8.939003,-10.939003,-13.939003,-15.939003,-12.939003,-11.939003,-11.939003,-9.939003,-4.939003,-0.939003,-0.939003,-2.939003,-1.939003,-0.939003,-1.939003,-6.939003,-12.939003,-11.939003,-15.939003,-21.939003,-16.939003,-14.939003,-15.939003,-14.939003,-13.939003,-20.939003,-25.939003,-29.939003,-29.939003,-30.939003,-34.939003,-35.939003,-36.939003,-36.939003,-35.939003,-32.939003,-33.939003,-32.939003,-29.939003,-26.939003,-25.939003,-27.939003,-27.939003,-25.939003,-22.939003,-19.939003,-18.939003,-16.939003,-13.939003,-10.939003,-8.939003,-5.939003,-3.939003,-2.939003,-2.939003,5.060997,12.060997,-5.939003,-29.939003,-58.939003,-55.939003,-50.939003,-48.939003,-42.939003,-34.939003,-32.939003,-23.939003,-9.939003,-17.939003,-23.939003,-17.939003,-18.939003,-22.939003,-26.939003,-35.939003,-49.939003,-21.939003,-0.939003,-8.939003,-12.939003,-13.939003,-7.939003,-13.939003,-32.939003,-26.939003,-22.939003,-34.939003,-28.939003,-15.939003,-5.939003,-4.939003,-10.939003,-17.939003,-22.939003,-23.939003,-33.939003,-47.939003,-59.939003,-58.939003,-44.939003,-38.939003,-40.939003,-62.939003,-71.939,-73.939,-64.939,-63.939003,-71.939,-69.939,-59.939003,-31.939003,-25.939003,-27.939003,-21.939003,-28.939003,-49.939003,-56.939003,-58.939003,-54.939003,-53.939003,-54.939003,-59.939003,-62.939003,-64.939,-65.939,-67.939,-70.939,-73.939,-76.939,-76.939,-77.939,-78.939,-81.939,-84.939,-87.939,-89.939,-89.939,-88.939,-88.939,-88.939,-88.939,-87.939,-84.939,-83.939,-81.939,-78.939,-75.939,-72.939,-69.939,-65.939,-62.939003,-64.939,-64.939,-35.939003,-21.939003,-24.939003,-30.939003,-37.939003,-47.939003,-47.939003,-44.939003,-39.939003,-37.939003,-38.939003,-40.939003,-40.939003,-35.939003,-32.939003,-32.939
003,-37.939003,-43.939003,-48.939003,-46.939003,-45.939003,-45.939003,-51.939003,-56.939003,-55.939003,-54.939003,-54.939003,-61.939003,-66.939,-63.939003,-64.939,-65.939,-52.939003,-46.939003,-46.939003,-45.939003,-44.939003,-42.939003,-52.939003,-62.939003,-35.939003,-10.939003,12.060997,-9.939003,-26.939003,-18.939003,-15.939003,-13.939003,-11.939003,-11.939003,-15.939003,-29.939003,-37.939003,-29.939003,-25.939003,-22.939003,-27.939003,-31.939003,-33.939003,-32.939003,-32.939003,-36.939003,-37.939003,-39.939003,-43.939003,-43.939003,-43.939003,-41.939003,-41.939003,-41.939003,-41.939003,-42.939003,-42.939003,-41.939003,-41.939003,-39.939003,-35.939003,-33.939003,-36.939003,-41.939003,-41.939003,-37.939003,-30.939003,-24.939003,-21.939003,-24.939003,-20.939003,-15.939003,-17.939003,-19.939003,-22.939003,-28.939003,-33.939003,-30.939003,-27.939003,-25.939003,-30.939003,-31.939003,-29.939003,-31.939003,-33.939003,-32.939003,-27.939003,-21.939003,-20.939003,-20.939003,-23.939003,-21.939003,-22.939003,-32.939003,-34.939003,-32.939003,-22.939003,-15.939003,-13.939003,-11.939003,-11.939003,-12.939003,-10.939003,-7.939003,0.06099701,-11.939003,-44.939003,-63.939003,-74.939,-73.939,-74.939,-77.939,-78.939,-79.939,-77.939,-79.939,-81.939,-82.939,-83.939,-84.939,-86.939,-86.939,-86.939,-69.939,-62.939003,-84.939,-89.939,-88.939,-87.939,-88.939,-93.939,-90.939,-87.939,-88.939,-92.939,-98.939,-96.939,-95.939,-96.939,-95.939,-95.939,-94.939,-95.939,-96.939,-90.939,-55.939003,9.060997,27.060997,21.060997,-28.939003,-64.939,-90.939,-88.939,-85.939,-84.939,-83.939,-83.939,-83.939,-83.939,-83.939,-80.939,-80.939,-82.939,-80.939,-78.939,-77.939,-84.939,-93.939,-90.939,-89.939,-92.939,-97.939,-99.939,-94.939,-94.939,-96.939,-98.939,-85.939,-55.939003,-53.939003,-53.939003,-49.939003,-43.939003,-37.939003,-37.939003,-35.939003,-32.939003,-36.939003,-40.939003,-42.939003,-43.939003,-43.939003,-43.939003,-46.939003,-50.939003,-53.939003,-54.939003,-52.939003,-49.939003,-46.939003,-51.
939003,-52.939003,-47.939003,-45.939003,-44.939003,-46.939003,-50.939003,-55.939003,-62.939003,-58.939003,-47.939003,-47.939003,-52.939003,-58.939003,-54.939003,-46.939003,-57.939003,-56.939003,-45.939003,-46.939003,-45.939003,-39.939003,-44.939003,-54.939003,-61.939003,-68.939,-76.939,-84.939,-90.939,-89.939,-81.939,-71.939,-56.939003,-46.939003,-42.939003,-36.939003,-33.939003,65.061,68.061,73.061,76.061,76.061,74.061,73.061,71.061,69.061,67.061,66.061,70.061,67.061,57.060997,-20.939003,-84.939,-96.939,-97.939,-95.939,-98.939,-98.939,-96.939,-96.939,-95.939,-89.939,-86.939,-86.939,-85.939,-84.939,-86.939,-83.939,-80.939,-78.939,-75.939,-70.939,-69.939,-67.939,-63.939003,-57.939003,-53.939003,-53.939003,-45.939003,-35.939003,-26.939003,-23.939003,-26.939003,-25.939003,-24.939003,-20.939003,-18.939003,-16.939003,-12.939003,-8.939003,-4.939003,-4.939003,-5.939003,-6.939003,-8.939003,-10.939003,-8.939003,-11.939003,-18.939003,-24.939003,-29.939003,-30.939003,-30.939003,-33.939003,-41.939003,-49.939003,-56.939003,-61.939003,-66.939,-74.939,-79.939,-85.939,-90.939,-91.939,-90.939,-97.939,-76.939,15.060997,64.061,93.061,88.061,81.061,72.061,-24.939003,-100.939,-102.939,-102.939,-101.939,-101.939,-100.939,-99.939,-98.939,-96.939,-95.939,-94.939,-93.939,-93.939,-92.939,-91.939,-89.939,-88.939,-87.939,-86.939,-85.939,-82.939,-77.939,-71.939,-69.939,-67.939,-63.939003,-55.939003,-48.939003,-44.939003,-41.939003,-38.939003,-33.939003,-27.939003,-20.939003,-14.939003,-10.939003,-7.939003,-4.939003,-2.939003,2.060997,6.060997,5.060997,0.06099701,-4.939003,-1.939003,0.06099701,1.060997,-6.939003,-14.939003,-19.939003,-23.939003,-26.939003,-32.939003,-37.939003,-41.939003,-46.939003,-55.939003,-73.939,-45.939003,2.060997,46.060997,69.061,69.061,69.061,36.060997,-69.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-100.939,-99.939,-98.939,-97.939,-96.939,-95.939,-93.939,-91.939,-91.939,-90.939,-89.939,-87.939,-86.939,-84.939,-83.939,-80
.939,-75.939,-72.939,-71.939,-67.939,-62.939003,-56.939003,-52.939003,-49.939003,-41.939003,-35.939003,-28.939003,-22.939003,-17.939003,-16.939003,-13.939003,-11.939003,-8.939003,-4.939003,0.06099701,-1.939003,-2.939003,-0.939003,0.06099701,0.06099701,-5.939003,-9.939003,-13.939003,-17.939003,-17.939003,-11.939003,-13.939003,-18.939003,-6.939003,6.060997,21.060997,25.060997,29.060997,37.060997,43.060997,46.060997,48.060997,52.060997,59.060997,69.061,62.060997,11.060997,8.060997,24.060997,27.060997,46.060997,82.061,87.061,83.061,71.061,69.061,70.061,70.061,69.061,67.061,63.060997,60.060997,58.060997,57.060997,54.060997,56.060997,55.060997,50.060997,55.060997,49.060997,11.060997,18.060997,43.060997,34.060997,29.060997,28.060997,18.060997,8.060997,0.06099701,13.060997,33.060997,30.060997,27.060997,26.060997,-11.939003,-39.939003,-34.939003,-25.939003,-16.939003,-14.939003,-13.939003,-14.939003,-8.939003,-4.939003,-6.939003,-5.939003,-4.939003,0.06099701,0.06099701,-2.939003,-5.939003,-6.939003,-8.939003,-10.939003,-12.939003,-18.939003,-21.939003,-20.939003,-27.939003,-34.939003,-40.939003,-46.939003,-51.939003,-55.939003,-59.939003,-62.939003,-68.939,-75.939,-84.939,-89.939,-92.939,-96.939,-64.939,2.060997,22.060997,34.060997,43.060997,36.060997,20.060997,-3.939003,-41.939003,-96.939,-98.939,-96.939,-94.939,-93.939,-85.939,-1.939003,48.060997,67.061,65.061,42.060997,-27.939003,-64.939,-84.939,-81.939,-79.939,-77.939,-70.939,-64.939,-64.939,-60.939003,-56.939003,-56.939003,-54.939003,-49.939003,-43.939003,-36.939003,-30.939003,-27.939003,-24.939003,-21.939003,-18.939003,-16.939003,-17.939003,-16.939003,-11.939003,-10.939003,-11.939003,-8.939003,-8.939003,-12.939003,-18.939003,-21.939003,-14.939003,-17.939003,-23.939003,-26.939003,-29.939003,-31.939003,-33.939003,-34.939003,-31.939003,-35.939003,-40.939003,-39.939003,-41.939003,-43.939003,-54.939003,-56.939003,-34.939003,11.060997,60.060997,48.060997,32.060997,13.060997,-26.939003,-53.939003,-41.939003,-38.939003,-35.93
9003,-30.939003,-25.939003,-19.939003,-14.939003,-8.939003,-5.939003,-4.939003,-2.939003,2.060997,7.060997,14.060997,21.060997,28.060997,31.060997,30.060997,29.060997,26.060997,21.060997,16.060997,17.060997,18.060997,18.060997,18.060997,20.060997,18.060997,13.060997,6.060997,2.060997,-1.939003,-4.939003,-11.939003,-18.939003,-18.939003,-23.939003,-33.939003,-36.939003,-37.939003,-34.939003,-32.939003,-30.939003,-34.939003,-37.939003,-40.939003,-43.939003,-45.939003,-48.939003,-44.939003,-40.939003,-35.939003,-31.939003,-27.939003,-22.939003,-18.939003,-20.939003,-18.939003,-15.939003,-10.939003,-2.939003,9.060997,16.060997,20.060997,20.060997,19.060997,20.060997,25.060997,29.060997,31.060997,32.060997,33.060997,33.060997,42.060997,48.060997,4.060997,-35.939003,-71.939,-57.939003,-40.939003,-35.939003,-29.939003,-26.939003,-38.939003,-31.939003,-3.939003,-15.939003,-26.939003,-24.939003,-24.939003,-26.939003,-33.939003,-44.939003,-59.939003,-16.939003,6.060997,-31.939003,-30.939003,-11.939003,2.060997,-7.939003,-42.939003,-35.939003,-32.939003,-50.939003,-42.939003,-22.939003,-5.939003,3.060997,4.060997,-10.939003,-20.939003,-16.939003,-23.939003,-35.939003,-52.939003,-59.939003,-56.939003,-38.939003,-29.939003,-49.939003,-64.939,-75.939,-64.939,-58.939003,-56.939003,-71.939,-71.939,-31.939003,-3.939003,15.060997,18.060997,-11.939003,-75.939,-93.939,-100.939,-92.939,-92.939,-95.939,-97.939,-98.939,-98.939,-96.939,-95.939,-94.939,-93.939,-92.939,-91.939,-90.939,-90.939,-88.939,-87.939,-86.939,-86.939,-86.939,-82.939,-77.939,-72.939,-67.939,-62.939003,-60.939003,-57.939003,-55.939003,-51.939003,-48.939003,-44.939003,-39.939003,-35.939003,-33.939003,-37.939003,-43.939003,-52.939003,-56.939003,-56.939003,-46.939003,-36.939003,-27.939003,-26.939003,-27.939003,-24.939003,-26.939003,-30.939003,-37.939003,-41.939003,-38.939003,-37.939003,-39.939003,-49.939003,-59.939003,-67.939,-68.939,-69.939,-73.939,-82.939,-91.939,-90.939,-90.939,-92.939,-95.939,-98.939,-97.939,-95.939,-8
8.939,-57.939003,-43.939003,-46.939003,-46.939003,-45.939003,-40.939003,-51.939003,-62.939003,-35.939003,-12.939003,5.060997,-14.939003,-31.939003,-28.939003,-32.939003,-36.939003,-31.939003,-32.939003,-39.939003,-45.939003,-49.939003,-45.939003,-51.939003,-60.939003,-60.939003,-61.939003,-62.939003,-55.939003,-50.939003,-53.939003,-56.939003,-58.939003,-54.939003,-49.939003,-45.939003,-40.939003,-37.939003,-37.939003,-39.939003,-43.939003,-46.939003,-46.939003,-45.939003,-44.939003,-41.939003,-39.939003,-40.939003,-42.939003,-40.939003,-36.939003,-30.939003,-25.939003,-23.939003,-23.939003,-20.939003,-16.939003,-17.939003,-19.939003,-22.939003,-27.939003,-33.939003,-36.939003,-41.939003,-45.939003,-43.939003,-43.939003,-44.939003,-48.939003,-50.939003,-51.939003,-49.939003,-47.939003,-47.939003,-46.939003,-46.939003,-48.939003,-49.939003,-45.939003,-47.939003,-52.939003,-54.939003,-53.939003,-49.939003,-50.939003,-50.939003,-48.939003,-44.939003,-41.939003,-36.939003,-37.939003,-43.939003,-49.939003,-52.939003,-48.939003,-50.939003,-55.939003,-58.939003,-58.939003,-55.939003,-58.939003,-61.939003,-62.939003,-62.939003,-64.939,-66.939,-67.939,-67.939,-55.939003,-52.939003,-70.939,-75.939,-75.939,-73.939,-74.939,-77.939,-77.939,-78.939,-77.939,-81.939,-87.939,-87.939,-88.939,-89.939,-91.939,-91.939,-91.939,-92.939,-92.939,-90.939,-56.939003,10.060997,29.060997,22.060997,-29.939003,-66.939,-93.939,-93.939,-92.939,-93.939,-93.939,-94.939,-95.939,-95.939,-94.939,-95.939,-95.939,-94.939,-95.939,-96.939,-97.939,-97.939,-97.939,-87.939,-84.939,-89.939,-94.939,-95.939,-88.939,-88.939,-91.939,-94.939,-77.939,-40.939003,-43.939003,-47.939003,-41.939003,-32.939003,-23.939003,-25.939003,-26.939003,-26.939003,-29.939003,-33.939003,-35.939003,-37.939003,-38.939003,-41.939003,-44.939003,-46.939003,-50.939003,-52.939003,-48.939003,-47.939003,-47.939003,-51.939003,-52.939003,-50.939003,-44.939003,-41.939003,-49.939003,-58.939003,-67.939,-74.939,-68.939,-51.939003,-43.939003,-39.9390
03,-40.939003,-39.939003,-39.939003,-50.939003,-47.939003,-31.939003,-35.939003,-38.939003,-31.939003,-34.939003,-42.939003,-43.939003,-47.939003,-56.939003,-67.939,-76.939,-76.939,-75.939,-74.939,-71.939,-70.939,-69.939,-64.939,-60.939003,65.061,66.061,67.061,70.061,70.061,69.061,71.061,72.061,72.061,70.061,68.061,66.061,66.061,66.061,-11.939003,-76.939,-91.939,-89.939,-80.939,-84.939,-82.939,-74.939,-73.939,-70.939,-60.939003,-54.939003,-48.939003,-46.939003,-46.939003,-45.939003,-35.939003,-26.939003,-23.939003,-22.939003,-21.939003,-20.939003,-19.939003,-16.939003,-11.939003,-7.939003,-5.939003,-9.939003,-14.939003,-16.939003,-18.939003,-17.939003,-20.939003,-21.939003,-19.939003,-24.939003,-33.939003,-38.939003,-43.939003,-46.939003,-52.939003,-58.939003,-65.939,-68.939,-71.939,-75.939,-79.939,-83.939,-85.939,-86.939,-86.939,-85.939,-84.939,-88.939,-90.939,-91.939,-92.939,-93.939,-96.939,-98.939,-99.939,-99.939,-100.939,-100.939,-101.939,-78.939,0.06099701,57.060997,101.061,90.061,83.061,80.061,-14.939003,-91.939,-100.939,-100.939,-96.939,-93.939,-90.939,-86.939,-81.939,-75.939,-68.939,-64.939,-62.939003,-58.939003,-54.939003,-49.939003,-43.939003,-37.939003,-33.939003,-30.939003,-27.939003,-19.939003,-12.939003,-6.939003,-5.939003,-5.939003,-5.939003,-4.939003,-2.939003,-2.939003,-2.939003,-3.939003,-7.939003,-10.939003,-12.939003,-16.939003,-19.939003,-25.939003,-33.939003,-39.939003,-41.939003,-44.939003,-52.939003,-60.939003,-68.939,-66.939,-68.939,-73.939,-78.939,-82.939,-84.939,-84.939,-84.939,-84.939,-85.939,-89.939,-89.939,-91.939,-95.939,-62.939003,-10.939003,64.061,94.061,79.061,90.061,63.060997,-60.939003,-99.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-100.939,-96.939,-91.939,-86.939,-82.939,-78.939,-74.939,-68.939,-61.939003,-53.939003,-51.939003,-46.939003,-40.939003,-35.939003,-30.939003,-23.939003,-20.939003,-17.939003,-11.939003,-7.939003,-4.939003,-3.939003,-3.939003,-3.939003,-0.939003,2.060997,-0.939003,-3.939003,-8.93
9003,-10.939003,-13.939003,-15.939003,-18.939003,-22.939003,-29.939003,-32.939003,-35.939003,-41.939003,-47.939003,-53.939003,-60.939003,-65.939,-66.939,-71.939,-79.939,-72.939,-59.939003,-31.939003,-27.939003,-29.939003,8.060997,43.060997,75.061,68.061,62.060997,69.061,69.061,67.061,69.061,68.061,64.061,80.061,77.061,24.060997,13.060997,18.060997,19.060997,38.060997,73.061,78.061,76.061,62.060997,54.060997,49.060997,48.060997,46.060997,41.060997,32.060997,24.060997,20.060997,15.060997,11.060997,9.060997,5.060997,0.06099701,-1.939003,-5.939003,-18.939003,-22.939003,-22.939003,-24.939003,-25.939003,-26.939003,-34.939003,-31.939003,-1.939003,17.060997,30.060997,33.060997,31.060997,24.060997,-10.939003,-34.939003,-26.939003,-16.939003,-5.939003,-16.939003,-24.939003,-29.939003,-37.939003,-43.939003,-44.939003,-49.939003,-53.939003,-54.939003,-60.939003,-67.939,-70.939,-71.939,-70.939,-75.939,-82.939,-83.939,-84.939,-84.939,-85.939,-87.939,-89.939,-89.939,-89.939,-91.939,-92.939,-91.939,-93.939,-97.939,-99.939,-100.939,-101.939,-101.939,-66.939,3.060997,20.060997,31.060997,38.060997,31.060997,15.060997,-4.939003,-40.939003,-91.939,-83.939,-71.939,-64.939,-62.939003,-57.939003,-26.939003,-9.939003,-7.939003,-10.939003,-16.939003,-26.939003,-25.939003,-20.939003,-16.939003,-12.939003,-8.939003,-5.939003,-3.939003,-9.939003,-12.939003,-14.939003,-17.939003,-18.939003,-18.939003,-21.939003,-23.939003,-25.939003,-26.939003,-26.939003,-29.939003,-35.939003,-41.939003,-46.939003,-49.939003,-53.939003,-56.939003,-58.939003,-60.939003,-56.939003,-44.939003,-43.939003,-48.939003,-58.939003,-54.939003,-44.939003,-41.939003,-38.939003,-34.939003,-30.939003,-25.939003,-16.939003,-10.939003,-3.939003,-4.939003,-1.939003,8.060997,5.060997,5.060997,13.060997,34.060997,59.060997,53.060997,47.060997,43.060997,29.060997,21.060997,30.060997,30.060997,26.060997,28.060997,24.060997,15.060997,10.060997,6.060997,2.060997,-1.939003,-5.939003,-9.939003,-11.939003,-12.939003,-12.939003,-13.939003
,-16.939003,-22.939003,-28.939003,-28.939003,-29.939003,-34.939003,-34.939003,-32.939003,-27.939003,-26.939003,-28.939003,-30.939003,-28.939003,-21.939003,-17.939003,-13.939003,-10.939003,-12.939003,-14.939003,-3.939003,1.060997,1.060997,-7.939003,-13.939003,-6.939003,-2.939003,0.06099701,-2.939003,-4.939003,-4.939003,3.060997,9.060997,11.060997,16.060997,22.060997,26.060997,29.060997,31.060997,36.060997,33.060997,8.060997,2.060997,4.060997,24.060997,45.060997,67.061,60.060997,50.060997,41.060997,36.060997,32.060997,30.060997,27.060997,26.060997,27.060997,28.060997,22.060997,20.060997,15.060997,-30.939003,-57.939003,-66.939,-48.939003,-31.939003,-23.939003,-17.939003,-13.939003,-34.939003,-32.939003,-8.939003,-0.939003,2.060997,-5.939003,-23.939003,-41.939003,-29.939003,-32.939003,-49.939003,-15.939003,1.060997,-42.939003,-38.939003,-16.939003,1.060997,7.060997,2.060997,0.06099701,-1.939003,-4.939003,1.060997,8.060997,-0.939003,-1.939003,5.060997,4.060997,4.060997,9.060997,-1.939003,-21.939003,-38.939003,-53.939003,-64.939,-43.939003,-27.939003,-28.939003,-42.939003,-58.939003,-67.939,-63.939003,-48.939003,-66.939,-76.939,-52.939003,-1.939003,53.060997,36.060997,-11.939003,-90.939,-101.939,-102.939,-97.939,-93.939,-87.939,-82.939,-80.939,-80.939,-74.939,-67.939,-66.939,-61.939003,-55.939003,-49.939003,-46.939003,-45.939003,-42.939003,-39.939003,-35.939003,-33.939003,-31.939003,-30.939003,-28.939003,-24.939003,-23.939003,-24.939003,-25.939003,-25.939003,-23.939003,-25.939003,-27.939003,-32.939003,-31.939003,-33.939003,-42.939003,-48.939003,-50.939003,-25.939003,-8.939003,-0.939003,-15.939003,-38.939003,-72.939,-79.939,-76.939,-79.939,-82.939,-85.939,-88.939,-89.939,-88.939,-88.939,-88.939,-91.939,-93.939,-95.939,-95.939,-95.939,-96.939,-98.939,-100.939,-96.939,-93.939,-90.939,-85.939,-80.939,-79.939,-72.939,-63.939003,-48.939003,-43.939003,-48.939003,-48.939003,-46.939003,-41.939003,-39.939003,-39.939003,-40.939003,-43.939003,-47.939003,-44.939003,-43.939003,-50.9390
03,-51.939003,-51.939003,-51.939003,-49.939003,-44.939003,-44.939003,-44.939003,-40.939003,-38.939003,-35.939003,-32.939003,-31.939003,-30.939003,-24.939003,-19.939003,-18.939003,-12.939003,-4.939003,1.060997,1.060997,-4.939003,7.060997,11.060997,-10.939003,-37.939003,-63.939003,-58.939003,-55.939003,-52.939003,-52.939003,-51.939003,-47.939003,-45.939003,-42.939003,-36.939003,-31.939003,-27.939003,-26.939003,-25.939003,-25.939003,-24.939003,-21.939003,-18.939003,-18.939003,-21.939003,-36.939003,-33.939003,12.060997,27.060997,27.060997,19.060997,13.060997,10.060997,3.060997,-2.939003,-0.939003,5.060997,11.060997,5.060997,2.060997,2.060997,-0.939003,-8.939003,-30.939003,-31.939003,-21.939003,-10.939003,-9.939003,-18.939003,-17.939003,-15.939003,-13.939003,-17.939003,-22.939003,-16.939003,-22.939003,-38.939003,-51.939003,-56.939003,-47.939003,-48.939003,-53.939003,-54.939003,-53.939003,-48.939003,-49.939003,-49.939003,-48.939003,-47.939003,-45.939003,-45.939003,-45.939003,-47.939003,-43.939003,-41.939003,-46.939003,-46.939003,-46.939003,-45.939003,-42.939003,-41.939003,-46.939003,-47.939003,-43.939003,-45.939003,-48.939003,-48.939003,-49.939003,-52.939003,-52.939003,-53.939003,-51.939003,-53.939003,-56.939003,-58.939003,-54.939003,-47.939003,-39.939003,-35.939003,-44.939003,-52.939003,-59.939003,-58.939003,-57.939003,-58.939003,-60.939003,-62.939003,-66.939,-67.939,-66.939,-69.939,-69.939,-66.939,-70.939,-74.939,-75.939,-77.939,-79.939,-82.939,-86.939,-91.939,-90.939,-87.939,-83.939,-86.939,-92.939,-90.939,-70.939,-32.939003,-41.939003,-46.939003,-23.939003,-10.939003,-1.939003,-13.939003,-22.939003,-28.939003,-25.939003,-24.939003,-29.939003,-31.939003,-32.939003,-34.939003,-37.939003,-41.939003,-42.939003,-43.939003,-47.939003,-50.939003,-55.939003,-55.939003,-53.939003,-49.939003,-48.939003,-52.939003,-72.939,-80.939,-82.939,-75.939,-65.939,-51.939003,-43.939003,-38.939003,-36.939003,-44.939003,-54.939003,-59.939003,-52.939003,-32.939003,-40.939003,-49.939003,-47.93
9003,-49.939003,-51.939003,-47.939003,-46.939003,-45.939003,-51.939003,-58.939003,-66.939,-67.939,-65.939,-63.939003,-66.939,-74.939,-76.939,-77.939,58.060997,56.060997,53.060997,51.060997,49.060997,45.060997,43.060997,41.060997,38.060997,33.060997,27.060997,21.060997,18.060997,16.060997,-21.939003,-50.939003,-54.939003,-50.939003,-43.939003,-43.939003,-40.939003,-34.939003,-35.939003,-35.939003,-28.939003,-26.939003,-25.939003,-24.939003,-25.939003,-26.939003,-25.939003,-23.939003,-21.939003,-22.939003,-25.939003,-28.939003,-32.939003,-35.939003,-34.939003,-35.939003,-38.939003,-40.939003,-38.939003,-20.939003,-22.939003,-43.939003,-51.939003,-56.939003,-56.939003,-60.939003,-66.939,-71.939,-75.939,-78.939,-82.939,-87.939,-92.939,-93.939,-95.939,-98.939,-101.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-99.939,-100.939,-100.939,-99.939,-99.939,-96.939,-94.939,-91.939,-88.939,-86.939,-84.939,-67.939,-16.939003,22.060997,52.060997,42.060997,35.060997,31.060997,-16.939003,-53.939003,-54.939003,-52.939003,-49.939003,-45.939003,-41.939003,-35.939003,-33.939003,-31.939003,-29.939003,-27.939003,-26.939003,-26.939003,-26.939003,-26.939003,-25.939003,-24.939003,-24.939003,-24.939003,-25.939003,-25.939003,-24.939003,-23.939003,-24.939003,-27.939003,-32.939003,-34.939003,-37.939003,-38.939003,-39.939003,-41.939003,-44.939003,-48.939003,-51.939003,-54.939003,-57.939003,-62.939003,-68.939,-73.939,-75.939,-78.939,-84.939,-89.939,-94.939,-93.939,-95.939,-99.939,-101.939,-103.939,-102.939,-101.939,-100.939,-100.939,-99.939,-99.939,-99.939,-99.939,-98.939,-69.939,-25.939003,47.060997,73.061,52.060997,61.060997,43.060997,-41.939003,-67.939,-69.939,-63.939003,-60.939003,-61.939003,-57.939003,-54.939003,-52.939003,-50.939003,-46.939003,-40.939003,-36.939003,-33.939003,-32.939003,-30.939003,-27.939003,-24.939003,-21.939003,-24.939003,-23.939003,-21.939003,-19.939003,-17.939003,-15.939003,-18.939003,-22.939003,-21.939003,-21.939003,-22.939003,-26.939003,-30.939003,
-33.939003,-34.939003,-34.939003,-38.939003,-42.939003,-46.939003,-49.939003,-52.939003,-53.939003,-56.939003,-60.939003,-65.939,-68.939,-71.939,-75.939,-79.939,-83.939,-88.939,-93.939,-93.939,-96.939,-101.939,-92.939,-74.939,-33.939003,-26.939003,-28.939003,8.060997,43.060997,76.061,65.061,55.060997,57.060997,54.060997,50.060997,49.060997,45.060997,39.060997,46.060997,42.060997,8.060997,-0.939003,1.060997,1.060997,8.060997,22.060997,23.060997,21.060997,14.060997,9.060997,6.060997,4.060997,1.060997,-1.939003,-6.939003,-10.939003,-10.939003,-10.939003,-8.939003,-8.939003,-10.939003,-11.939003,-9.939003,-9.939003,-15.939003,-17.939003,-16.939003,-15.939003,-13.939003,-9.939003,-13.939003,-12.939003,0.06099701,15.060997,31.060997,34.060997,32.060997,26.060997,-8.939003,-39.939003,-53.939003,-52.939003,-46.939003,-54.939003,-60.939003,-64.939,-70.939,-75.939,-77.939,-80.939,-83.939,-85.939,-89.939,-94.939,-96.939,-96.939,-95.939,-98.939,-101.939,-100.939,-99.939,-98.939,-96.939,-94.939,-93.939,-91.939,-88.939,-85.939,-81.939,-76.939,-75.939,-73.939,-71.939,-70.939,-69.939,-68.939,-44.939003,1.060997,20.060997,33.060997,37.060997,31.060997,21.060997,-4.939003,-35.939003,-71.939,-53.939003,-35.939003,-31.939003,-31.939003,-32.939003,-27.939003,-23.939003,-25.939003,-24.939003,-24.939003,-28.939003,-28.939003,-26.939003,-26.939003,-25.939003,-23.939003,-22.939003,-21.939003,-23.939003,-25.939003,-26.939003,-29.939003,-29.939003,-27.939003,-28.939003,-27.939003,-25.939003,-23.939003,-22.939003,-22.939003,-23.939003,-23.939003,-21.939003,-19.939003,-21.939003,-22.939003,-22.939003,-19.939003,-14.939003,-7.939003,-10.939003,-13.939003,-8.939003,-13.939003,-22.939003,-11.939003,-3.939003,-0.939003,2.060997,5.060997,8.060997,10.060997,11.060997,8.060997,8.060997,12.060997,9.060997,6.060997,2.060997,7.060997,16.060997,11.060997,9.060997,9.060997,8.060997,7.060997,10.060997,10.060997,9.060997,9.060997,5.060997,-0.939003,-3.939003,-4.939003,-7.939003,-10.939003,-11.939003,-10.9390
03,-11.939003,-12.939003,-11.939003,-10.939003,-8.939003,-12.939003,-17.939003,-16.939003,-16.939003,-18.939003,-22.939003,-23.939003,-18.939003,-19.939003,-23.939003,-26.939003,-23.939003,-15.939003,-15.939003,-15.939003,-15.939003,-18.939003,-22.939003,-20.939003,-21.939003,-22.939003,-17.939003,-13.939003,-11.939003,-17.939003,-22.939003,-8.939003,9.060997,29.060997,32.060997,33.060997,35.060997,35.060997,34.060997,34.060997,33.060997,32.060997,35.060997,29.060997,-0.939003,-8.939003,-5.939003,9.060997,22.060997,35.060997,26.060997,16.060997,9.060997,3.060997,-1.939003,-8.939003,-10.939003,-10.939003,-10.939003,-10.939003,-11.939003,-12.939003,-18.939003,-52.939003,-64.939,-52.939003,-36.939003,-23.939003,-20.939003,-13.939003,-7.939003,-20.939003,-20.939003,-8.939003,5.060997,10.060997,-7.939003,-27.939003,-45.939003,-28.939003,-19.939003,-18.939003,-2.939003,0.06099701,-27.939003,-23.939003,-6.939003,6.060997,14.060997,18.060997,12.060997,7.060997,6.060997,9.060997,11.060997,-15.939003,-23.939003,-9.939003,4.060997,16.060997,23.060997,1.060997,-30.939003,-44.939003,-51.939003,-51.939003,-32.939003,-19.939003,-21.939003,-31.939003,-44.939003,-61.939003,-63.939003,-52.939003,-61.939003,-68.939,-68.939,-32.939003,11.060997,-0.939003,-24.939003,-60.939003,-62.939003,-59.939003,-57.939003,-54.939003,-51.939003,-47.939003,-46.939003,-45.939003,-42.939003,-39.939003,-39.939003,-37.939003,-36.939003,-37.939003,-36.939003,-35.939003,-36.939003,-37.939003,-38.939003,-40.939003,-42.939003,-44.939003,-46.939003,-49.939003,-50.939003,-52.939003,-53.939003,-53.939003,-53.939003,-55.939003,-57.939003,-61.939003,-61.939003,-63.939003,-70.939,-71.939,-64.939,-2.939003,29.060997,30.060997,-2.939003,-42.939003,-87.939,-95.939,-88.939,-89.939,-89.939,-89.939,-85.939,-81.939,-78.939,-77.939,-76.939,-74.939,-71.939,-70.939,-68.939,-66.939,-65.939,-64.939,-62.939003,-61.939003,-59.939003,-55.939003,-52.939003,-49.939003,-47.939003,-44.939003,-43.939003,-44.939003,-45.939003,-45.93900
3,-45.939003,-44.939003,-41.939003,-39.939003,-37.939003,-35.939003,-34.939003,-35.939003,-35.939003,-35.939003,-31.939003,-30.939003,-28.939003,-26.939003,-23.939003,-19.939003,-29.939003,-33.939003,-21.939003,-9.939003,1.060997,-2.939003,-3.939003,-2.939003,-5.939003,-8.939003,-8.939003,-2.939003,3.060997,3.060997,0.06099701,-6.939003,-1.939003,-2.939003,-17.939003,-35.939003,-51.939003,-48.939003,-45.939003,-41.939003,-40.939003,-39.939003,-36.939003,-38.939003,-39.939003,-34.939003,-30.939003,-25.939003,-24.939003,-24.939003,-25.939003,-25.939003,-22.939003,-20.939003,-20.939003,-21.939003,-34.939003,-31.939003,11.060997,25.060997,26.060997,18.060997,14.060997,15.060997,7.060997,1.060997,7.060997,15.060997,21.060997,20.060997,19.060997,19.060997,18.060997,8.060997,-20.939003,-17.939003,-0.939003,14.060997,18.060997,11.060997,12.060997,14.060997,15.060997,14.060997,11.060997,18.060997,-3.939003,-54.939003,-71.939,-77.939,-71.939,-71.939,-74.939,-75.939,-74.939,-71.939,-71.939,-71.939,-70.939,-69.939,-68.939,-67.939,-66.939,-65.939,-56.939003,-53.939003,-64.939,-66.939,-66.939,-65.939,-64.939,-62.939003,-65.939,-66.939,-63.939003,-62.939003,-62.939003,-63.939003,-64.939,-65.939,-65.939,-65.939,-62.939003,-65.939,-67.939,-58.939003,-47.939003,-34.939003,-31.939003,-34.939003,-49.939003,-57.939003,-60.939003,-59.939003,-58.939003,-59.939003,-58.939003,-59.939003,-61.939003,-61.939003,-58.939003,-59.939003,-60.939003,-60.939003,-63.939003,-63.939003,-63.939003,-69.939,-77.939,-83.939,-86.939,-87.939,-88.939,-87.939,-85.939,-90.939,-96.939,-92.939,-75.939,-45.939003,-44.939003,-41.939003,-20.939003,-10.939003,-5.939003,-13.939003,-19.939003,-22.939003,-19.939003,-19.939003,-24.939003,-28.939003,-31.939003,-30.939003,-32.939003,-38.939003,-37.939003,-38.939003,-44.939003,-50.939003,-56.939003,-58.939003,-54.939003,-44.939003,-50.939003,-57.939003,-65.939,-73.939,-79.939,-75.939,-66.939,-52.939003,-50.939003,-49.939003,-48.939003,-48.939003,-52.939003,-58.939003,-54.939
003,-40.939003,-46.939003,-54.939003,-58.939003,-58.939003,-57.939003,-53.939003,-49.939003,-44.939003,-52.939003,-63.939003,-73.939,-74.939,-71.939,-69.939,-72.939,-78.939,-80.939,-80.939,43.060997,38.060997,32.060997,25.060997,20.060997,13.060997,5.060997,-1.939003,-8.939003,-16.939003,-25.939003,-32.939003,-39.939003,-47.939003,-35.939003,-21.939003,-11.939003,-6.939003,-4.939003,-0.939003,2.060997,2.060997,-0.939003,-3.939003,-3.939003,-8.939003,-14.939003,-16.939003,-18.939003,-22.939003,-33.939003,-42.939003,-42.939003,-46.939003,-52.939003,-61.939003,-70.939,-79.939,-82.939,-89.939,-99.939,-93.939,-77.939,-30.939003,-30.939003,-77.939,-93.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-97.939,-97.939,-97.939,-92.939,-91.939,-93.939,-91.939,-89.939,-86.939,-82.939,-76.939,-71.939,-65.939,-59.939003,-55.939003,-48.939003,-30.939003,-18.939003,-12.939003,-18.939003,-24.939003,-30.939003,-21.939003,-11.939003,0.06099701,3.060997,3.060997,7.060997,11.060997,18.060997,14.060997,8.060997,2.060997,0.06099701,0.06099701,-4.939003,-10.939003,-17.939003,-24.939003,-32.939003,-36.939003,-42.939003,-49.939003,-60.939003,-67.939,-72.939,-76.939,-80.939,-89.939,-95.939,-99.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-96.939,-95.939,-93.939,-93.939,-92.939,-86.939,-88.939,-88.939,-85.939,-66.939,-36.939003,16.060997,32.060997,11.060997,12.060997,6.060997,-20.939003,-26.939003,-24.939003,-11.939003,-6.939003,-8.939003,-0.939003,5.060997,5.060997,7.060997,11.060997,14.060997,15.060997,15.060997,12.060997,9.060997,8.060997,3.060997,-2.939003,-10.939003,-16.939003,-20.939003,-23.939003,-28.939003,-34.939003,-44.939003,-56.939003,-62.939003,-67.939,-72.939,-80.939,-88.939,-93.939,-98.939,-103.939,-103.939,-103.939,-103.939,-103.
939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-98.939,-99.939,-99.939,-99.939,-98.939,-98.939,-92.939,-74.939,-29.939003,-19.939003,-21.939003,1.060997,24.060997,50.060997,39.060997,28.060997,26.060997,21.060997,15.060997,11.060997,7.060997,3.060997,-2.939003,-8.939003,-16.939003,-19.939003,-17.939003,-18.939003,-23.939003,-34.939003,-39.939003,-41.939003,-39.939003,-36.939003,-33.939003,-38.939003,-40.939003,-39.939003,-36.939003,-33.939003,-29.939003,-21.939003,-11.939003,-9.939003,-6.939003,-3.939003,7.060997,11.060997,1.060997,8.060997,20.060997,21.060997,27.060997,37.060997,37.060997,28.060997,3.060997,11.060997,33.060997,34.060997,32.060997,28.060997,-6.939003,-47.939003,-90.939,-103.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-97.939,-95.939,-92.939,-90.939,-88.939,-85.939,-80.939,-77.939,-75.939,-73.939,-68.939,-59.939003,-51.939003,-44.939003,-39.939003,-34.939003,-30.939003,-27.939003,-27.939003,-26.939003,-17.939003,-1.939003,19.060997,36.060997,36.060997,34.060997,29.060997,-4.939003,-29.939003,-48.939003,-22.939003,-0.939003,-3.939003,-8.939003,-13.939003,-15.939003,-15.939003,-12.939003,-7.939003,-7.939003,-27.939003,-42.939003,-54.939003,-58.939003,-61.939003,-61.939003,-62.939003,-60.939003,-54.939003,-51.939003,-50.939003,-51.939003,-49.939003,-44.939003,-39.939003,-31.939003,-21.939003,-15.939003,-11.939003,-7.939003,0.06099701,11.060997,20.060997,29.060997,34.060997,36.060997,38.060997,48.060997,52.060997,47.060997,36.060997,36.060997,66.061,46.060997,12.060997,29.060997,38.060997,38.060997,37.060997,35.060997,31.060997,24.060997,15.060997,10.060997,4.060997,-1.939003,-7.939003,-14.939003,-27.939003,-33.939003,-37.939003,-42.939003,-44.939003,-42.939003,-37.939003,-35.939003,-38.939003,-37.939003,-34.939003,-35.939003,-34.939003,-33.939003,-28.939003,-23.939003,-24.939003,-22.939003,-19.939003,-10.939003,-6.939003,-6.939003,-2.939
003,3.060997,13.060997,14.060997,11.060997,12.060997,13.060997,12.060997,5.060997,0.06099701,4.060997,0.06099701,-5.939003,-8.939003,-7.939003,-5.939003,-11.939003,-18.939003,-23.939003,-30.939003,-37.939003,-51.939003,-61.939003,-66.939,-42.939003,-23.939003,-29.939003,-46.939003,-62.939003,-27.939003,10.060997,53.060997,45.060997,35.060997,36.060997,29.060997,20.060997,14.060997,9.060997,7.060997,7.060997,0.06099701,-25.939003,-29.939003,-25.939003,-24.939003,-25.939003,-29.939003,-34.939003,-38.939003,-40.939003,-43.939003,-48.939003,-56.939003,-56.939003,-51.939003,-54.939003,-53.939003,-44.939003,-43.939003,-48.939003,-66.939,-62.939003,-35.939003,-24.939003,-17.939003,-18.939003,-12.939003,-4.939003,-5.939003,-6.939003,-6.939003,7.060997,10.060997,-17.939003,-32.939003,-40.939003,-25.939003,-6.939003,17.060997,11.060997,1.060997,-4.939003,-0.939003,9.060997,14.060997,17.060997,19.060997,10.060997,4.060997,0.06099701,4.060997,5.060997,-35.939003,-46.939003,-30.939003,-2.939003,19.060997,31.060997,0.06099701,-43.939003,-53.939003,-49.939003,-29.939003,-16.939003,-11.939003,-21.939003,-26.939003,-33.939003,-51.939003,-60.939003,-60.939003,-54.939003,-57.939003,-79.939,-73.939,-56.939003,-50.939003,-39.939003,-19.939003,-13.939003,-9.939003,-8.939003,-9.939003,-12.939003,-13.939003,-14.939003,-13.939003,-16.939003,-20.939003,-21.939003,-25.939003,-30.939003,-40.939003,-45.939003,-43.939003,-50.939003,-56.939003,-65.939,-72.939,-79.939,-82.939,-90.939,-100.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-93.939,-77.939,12.060997,51.060997,40.060997,-0.939003,-44.939003,-85.939,-90.939,-80.939,-78.939,-73.939,-68.939,-59.939003,-51.939003,-43.939003,-42.939003,-41.939003,-36.939003,-31.939003,-29.939003,-24.939003,-21.939003,-19.939003,-16.939003,-14.939003,-17.939003,-18.939003,-14.939003,-18.939003,-21.939003,-19.939003,-23.939003,-30.939003,-43.939003,-47.939003,-41.939003,-42.939003,-42.939003,-40.939003,-43.9
39003,-44.939003,-27.939003,-11.939003,3.060997,-11.939003,-19.939003,-1.939003,3.060997,5.060997,11.060997,13.060997,11.060997,-11.939003,-22.939003,-2.939003,16.060997,32.060997,20.060997,15.060997,16.060997,1.060997,-10.939003,-12.939003,-11.939003,-11.939003,-18.939003,-23.939003,-28.939003,-34.939003,-39.939003,-37.939003,-33.939003,-28.939003,-31.939003,-30.939003,-26.939003,-22.939003,-20.939003,-21.939003,-27.939003,-35.939003,-33.939003,-29.939003,-24.939003,-22.939003,-22.939003,-25.939003,-24.939003,-21.939003,-23.939003,-23.939003,-21.939003,-29.939003,-29.939003,-11.939003,-6.939003,-6.939003,-10.939003,-8.939003,-2.939003,-10.939003,-15.939003,-7.939003,1.060997,8.060997,13.060997,17.060997,18.060997,18.060997,9.060997,-15.939003,-9.939003,10.060997,24.060997,31.060997,30.060997,31.060997,32.060997,34.060997,38.060997,43.060997,49.060997,9.060997,-74.939,-92.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-101.939,-100.939,-101.939,-101.939,-101.939,-98.939,-94.939,-77.939,-71.939,-95.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-100.939,-97.939,-94.939,-97.939,-98.939,-96.939,-96.939,-95.939,-92.939,-94.939,-95.939,-72.939,-39.939003,6.060997,7.060997,-9.939003,-51.939003,-69.939,-76.939,-75.939,-74.939,-75.939,-71.939,-68.939,-68.939,-65.939,-63.939003,-60.939003,-61.939003,-65.939,-64.939,-61.939003,-56.939003,-67.939,-82.939,-87.939,-87.939,-81.939,-86.939,-90.939,-90.939,-96.939,-102.939,-96.939,-85.939,-67.939,-50.939003,-35.939003,-24.939003,-20.939003,-19.939003,-18.939003,-17.939003,-15.939003,-14.939003,-16.939003,-21.939003,-27.939003,-32.939003,-28.939003,-29.939003,-34.939003,-34.939003,-35.939003,-42.939003,-48.939003,-55.939003,-60.939003,-55.939003,-39.939003,-51.939003,-59.939003,-48.939003,-55.939003,-70.939,-75.939,-70.939,-55.939003,-61.939003,-65.939,-64.939,-53.939003,-42.939003,-52.939003,-54.939003,-51.939003,-53.939003,-56.939003,-64.939,-65.939,-63.939003,-60.939003,-55.939003,-47
.939003,-61.939003,-76.939,-85.939,-84.939,-80.939,-79.939,-79.939,-79.939,-78.939,-78.939,-20.939003,-21.939003,-22.939003,-23.939003,-24.939003,-23.939003,-23.939003,-22.939003,-19.939003,-19.939003,-19.939003,-17.939003,-14.939003,-10.939003,-18.939003,-26.939003,-30.939003,-31.939003,-31.939003,-37.939003,-42.939003,-45.939003,-51.939003,-55.939003,-56.939003,-57.939003,-59.939003,-63.939003,-63.939003,-58.939003,-68.939,-77.939,-77.939,-80.939,-83.939,-86.939,-90.939,-95.939,-96.939,-98.939,-102.939,-98.939,-88.939,-35.939003,-28.939003,-67.939,-88.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-100.939,-99.939,-97.939,-93.939,-90.939,-87.939,-82.939,-78.939,-72.939,-66.939,-60.939003,-54.939003,-49.939003,-44.939003,-40.939003,-36.939003,-32.939003,-26.939003,-22.939003,-20.939003,-19.939003,-15.939003,-10.939003,-8.939003,-8.939003,-8.939003,-10.939003,-11.939003,-11.939003,-10.939003,-11.939003,-7.939003,-1.939003,-16.939003,-30.939003,-34.939003,-37.939003,-41.939003,-43.939003,-45.939003,-46.939003,-51.939003,-57.939003,-63.939003,-66.939,-68.939,-70.939,-72.939,-74.939,-77.939,-79.939,-81.939,-83.939,-85.939,-89.939,-91.939,-93.939,-93.939,-95.939,-98.939,-100.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-102.939,-102.939,-103.939,-101.939,-100.939,-100.939,-95.939,-91.939,-88.939,-85.939,-82.939,-75.939,-70.939,-67.939,-63.939003,-58.939003,-49.939003,-44.939003,-38.939003,-34.939003,-28.939003,-21.939003,-23.939003,-21.939003,-14.939003,-12.939003,-11.939003,-16.939003,-18.939003,-18.939003,-18.939003,-16.939003,-9.939003,-9.939003,-13.939003,-14.939003,-17.939003,-19.939003,-24.939003,-29.939003,-35.939003,-39.939003,-43.939003,-46.939003,-49.939003,-53.939003,-57.939003,-61.939003,-64.939,-67.939,-69.939,-72.939,-74.939,-75.939,-75.939,-77.939,-80.939,-83.939,-86.939,-89.939,-90.939,-91.939,-95.939,-98.939,-98.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-99.939,-97.939,-97.939,-95.939,-91.939,-85.939,-83.
939,-80.939,-73.939,-68.939,-62.939003,-56.939003,-51.939003,-48.939003,-41.939003,-34.939003,-24.939003,-16.939003,-10.939003,-10.939003,-12.939003,-16.939003,-18.939003,-20.939003,-22.939003,-21.939003,-19.939003,-20.939003,-21.939003,-21.939003,-17.939003,-15.939003,-16.939003,-15.939003,-12.939003,-13.939003,-8.939003,1.060997,10.060997,15.060997,13.060997,17.060997,24.060997,22.060997,22.060997,24.060997,28.060997,31.060997,35.060997,40.060997,45.060997,44.060997,44.060997,44.060997,53.060997,51.060997,20.060997,33.060997,62.060997,60.060997,62.060997,66.061,64.061,48.060997,2.060997,6.060997,28.060997,29.060997,27.060997,23.060997,-7.939003,-44.939003,-89.939,-100.939,-96.939,-90.939,-86.939,-86.939,-79.939,-73.939,-71.939,-67.939,-64.939,-60.939003,-53.939003,-43.939003,-42.939003,-41.939003,-40.939003,-36.939003,-33.939003,-31.939003,-27.939003,-23.939003,-23.939003,-22.939003,-19.939003,-18.939003,-19.939003,-20.939003,-20.939003,-17.939003,-21.939003,-25.939003,-27.939003,-31.939003,-36.939003,-44.939003,-37.939003,-15.939003,12.060997,35.060997,38.060997,35.060997,26.060997,-1.939003,-34.939003,-73.939,-48.939003,-26.939003,-30.939003,-36.939003,-38.939003,-0.939003,26.060997,45.060997,46.060997,37.060997,12.060997,-0.939003,-5.939003,5.060997,11.060997,13.060997,14.060997,16.060997,20.060997,23.060997,26.060997,28.060997,30.060997,31.060997,31.060997,31.060997,29.060997,28.060997,26.060997,24.060997,22.060997,20.060997,13.060997,8.060997,10.060997,6.060997,0.06099701,1.060997,-1.939003,-8.939003,-14.939003,-14.939003,-4.939003,-9.939003,-19.939003,-18.939003,-17.939003,-16.939003,-14.939003,-11.939003,-9.939003,-6.939003,-3.939003,-3.939003,-4.939003,-4.939003,-1.939003,0.06099701,-5.939003,-14.939003,-24.939003,-22.939003,-19.939003,-16.939003,-10.939003,-5.939003,-3.939003,-4.939003,-5.939003,-4.939003,-7.939003,-13.939003,-15.939003,-16.939003,-20.939003,-23.939003,-28.939003,-31.939003,-36.939003,-43.939003,-43.939003,-43.939003,-45.939003,-47.939003
,-51.939003,-55.939003,-59.939003,-64.939,-67.939,-68.939,-67.939,-68.939,-70.939,-71.939,-71.939,-70.939,-72.939,-75.939,-75.939,-76.939,-76.939,-78.939,-76.939,-71.939,-49.939003,-32.939003,-34.939003,-41.939003,-47.939003,-32.939003,-20.939003,-8.939003,-13.939003,-20.939003,-23.939003,-27.939003,-30.939003,-34.939003,-36.939003,-35.939003,-33.939003,-32.939003,-33.939003,-31.939003,-29.939003,-29.939003,-31.939003,-34.939003,-32.939003,-29.939003,-29.939003,-29.939003,-29.939003,-26.939003,-21.939003,-14.939003,-13.939003,-10.939003,5.060997,-13.939003,-44.939003,-56.939003,-53.939003,-34.939003,-24.939003,-16.939003,-13.939003,-3.939003,4.060997,-18.939003,-22.939003,-8.939003,8.060997,9.060997,-32.939003,-24.939003,0.06099701,-9.939003,-7.939003,5.060997,0.06099701,-4.939003,0.06099701,11.060997,24.060997,17.060997,8.060997,-3.939003,-15.939003,-19.939003,-5.939003,9.060997,20.060997,-0.939003,-17.939003,-30.939003,-23.939003,-7.939003,28.060997,19.060997,-6.939003,-14.939003,-14.939003,-6.939003,-20.939003,-32.939003,-37.939003,-41.939003,-47.939003,-56.939003,-53.939003,-35.939003,-40.939003,-52.939003,-74.939,-67.939,-47.939003,-15.939003,-15.939003,-46.939003,-54.939003,-57.939003,-58.939003,-62.939003,-67.939,-68.939,-70.939,-71.939,-73.939,-75.939,-76.939,-77.939,-79.939,-82.939,-84.939,-83.939,-85.939,-87.939,-90.939,-93.939,-95.939,-96.939,-99.939,-101.939,-97.939,-93.939,-92.939,-88.939,-83.939,-78.939,-73.939,-69.939,-67.939,-64.939,-60.939003,-54.939003,-46.939003,-28.939003,-23.939003,-33.939003,-32.939003,-33.939003,-36.939003,-35.939003,-31.939003,-34.939003,-34.939003,-33.939003,-30.939003,-28.939003,-26.939003,-31.939003,-38.939003,-35.939003,-34.939003,-38.939003,-38.939003,-40.939003,-48.939003,-51.939003,-53.939003,-60.939003,-64.939,-64.939,-68.939,-72.939,-74.939,-71.939,-64.939,-50.939003,-44.939003,-42.939003,-40.939003,-39.939003,-39.939003,-48.939003,-57.939003,-30.939003,-3.939003,22.060997,-4.939003,-23.939003,-11.939003,-8.939003,-9
.939003,-12.939003,-16.939003,-20.939003,-30.939003,-35.939003,-28.939003,-29.939003,-34.939003,-36.939003,-38.939003,-41.939003,-47.939003,-51.939003,-51.939003,-51.939003,-51.939003,-50.939003,-48.939003,-47.939003,-47.939003,-44.939003,-34.939003,-32.939003,-33.939003,-31.939003,-31.939003,-32.939003,-29.939003,-27.939003,-28.939003,-32.939003,-38.939003,-34.939003,-29.939003,-24.939003,-21.939003,-19.939003,-22.939003,-22.939003,-21.939003,-22.939003,-22.939003,-21.939003,-26.939003,-31.939003,-32.939003,-36.939003,-41.939003,-43.939003,-43.939003,-41.939003,-43.939003,-44.939003,-44.939003,-41.939003,-40.939003,-39.939003,-37.939003,-34.939003,-34.939003,-35.939003,-39.939003,-39.939003,-36.939003,-35.939003,-32.939003,-28.939003,-31.939003,-32.939003,-29.939003,-25.939003,-21.939003,-19.939003,-28.939003,-47.939003,-55.939003,-60.939003,-61.939003,-64.939,-66.939,-69.939,-71.939,-69.939,-75.939,-79.939,-75.939,-77.939,-79.939,-81.939,-81.939,-80.939,-66.939,-62.939003,-82.939,-89.939,-90.939,-90.939,-91.939,-92.939,-94.939,-95.939,-94.939,-92.939,-91.939,-95.939,-97.939,-97.939,-98.939,-98.939,-96.939,-98.939,-99.939,-76.939,-35.939003,24.060997,32.060997,12.060997,-55.939003,-84.939,-94.939,-94.939,-93.939,-94.939,-92.939,-91.939,-91.939,-90.939,-90.939,-89.939,-88.939,-88.939,-86.939,-80.939,-69.939,-74.939,-86.939,-91.939,-90.939,-84.939,-87.939,-92.939,-93.939,-97.939,-101.939,-100.939,-94.939,-82.939,-63.939003,-46.939003,-36.939003,-30.939003,-26.939003,-23.939003,-22.939003,-22.939003,-21.939003,-23.939003,-26.939003,-29.939003,-32.939003,-30.939003,-30.939003,-29.939003,-33.939003,-36.939003,-39.939003,-47.939003,-56.939003,-56.939003,-50.939003,-38.939003,-50.939003,-56.939003,-47.939003,-58.939003,-77.939,-82.939,-77.939,-65.939,-67.939,-66.939,-60.939003,-49.939003,-38.939003,-40.939003,-47.939003,-61.939003,-58.939003,-55.939003,-56.939003,-65.939,-76.939,-71.939,-62.939003,-51.939003,-72.939,-85.939,-74.939,-68.939,-66.939,-73.939,-74.939,-71.939,
-73.939,-75.939,-18.939003,-17.939003,-14.939003,-12.939003,-9.939003,-5.939003,-2.939003,0.06099701,6.060997,9.060997,12.060997,17.060997,23.060997,31.060997,2.060997,-25.939003,-43.939003,-50.939003,-51.939003,-63.939003,-69.939,-70.939,-74.939,-78.939,-79.939,-79.939,-78.939,-82.939,-80.939,-72.939,-80.939,-88.939,-89.939,-91.939,-93.939,-95.939,-95.939,-96.939,-94.939,-92.939,-93.939,-91.939,-83.939,-38.939003,-26.939003,-50.939003,-69.939,-81.939,-77.939,-75.939,-74.939,-70.939,-67.939,-65.939,-63.939003,-60.939003,-56.939003,-54.939003,-52.939003,-47.939003,-42.939003,-37.939003,-32.939003,-27.939003,-24.939003,-20.939003,-16.939003,-18.939003,-16.939003,-13.939003,-11.939003,-10.939003,-10.939003,-13.939003,-15.939003,-13.939003,-13.939003,-16.939003,-20.939003,-22.939003,-16.939003,3.060997,27.060997,21.060997,27.060997,44.060997,-3.939003,-46.939003,-64.939,-72.939,-75.939,-79.939,-82.939,-86.939,-90.939,-94.939,-98.939,-101.939,-103.939,-103.939,-103.939,-102.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-98.939,-95.939,-94.939,-93.939,-91.939,-88.939,-86.939,-84.939,-83.939,-82.939,-79.939,-76.939,-73.939,-71.939,-70.939,-66.939,-64.939,-63.939003,-58.939003,-53.939003,-49.939003,-46.939003,-44.939003,-38.939003,-35.939003,-34.939003,-30.939003,-25.939003,-21.939003,-19.939003,-16.939003,-13.939003,-9.939003,-6.939003,-8.939003,-10.939003,-10.939003,-12.939003,-14.939003,-9.939003,-4.939003,1.060997,5.060997,4.060997,-9.939003,-23.939003,-37.939003,-43.939003,-47.939003,-49.939003,-57.939003,-64.939,-70.939,-75.939,-81.939,-84.939,-89.939,-93.939,-95.939,-98.939,-102.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-100.939,-100.939,-98.939,-94.939,-93.939,-91.939,-88.939,-87.939,-85.939,-83.939,-81.939,-79.939,-78.939,-76.939,-74.939,-71.939,-69.939,-63.939003,-61.939003,-59.939003,-57.939003,-53.939003,-46.939003,-44.939003,-41.939003,-36.939003,-32.939003,-28.939003,-23.939003,-20.939003,-17.939003,-16.939003,-17.939003,-21.939003,
-18.939003,-13.939003,-15.939003,-17.939003,-20.939003,-18.939003,-16.939003,-16.939003,-12.939003,-6.939003,-4.939003,-2.939003,-2.939003,6.060997,10.060997,-0.939003,-2.939003,-0.939003,0.06099701,11.060997,34.060997,52.060997,61.060997,52.060997,53.060997,57.060997,60.060997,61.060997,63.060997,65.061,67.061,68.061,71.061,74.061,70.061,67.061,63.060997,69.061,62.060997,23.060997,35.060997,64.061,61.060997,58.060997,56.060997,53.060997,38.060997,0.06099701,5.060997,25.060997,27.060997,26.060997,23.060997,-8.939003,-42.939003,-70.939,-70.939,-60.939003,-58.939003,-56.939003,-53.939003,-46.939003,-40.939003,-41.939003,-38.939003,-36.939003,-35.939003,-30.939003,-21.939003,-22.939003,-24.939003,-23.939003,-22.939003,-21.939003,-21.939003,-19.939003,-15.939003,-17.939003,-18.939003,-16.939003,-15.939003,-16.939003,-18.939003,-17.939003,-14.939003,-18.939003,-22.939003,-23.939003,-24.939003,-26.939003,-34.939003,-30.939003,-14.939003,10.060997,31.060997,40.060997,36.060997,25.060997,-0.939003,-33.939003,-75.939,-33.939003,0.06099701,-6.939003,-12.939003,-14.939003,19.060997,44.060997,60.060997,58.060997,49.060997,30.060997,20.060997,16.060997,26.060997,30.060997,30.060997,29.060997,30.060997,30.060997,32.060997,33.060997,33.060997,33.060997,32.060997,32.060997,30.060997,25.060997,21.060997,18.060997,16.060997,13.060997,10.060997,2.060997,-4.939003,-4.939003,-7.939003,-12.939003,-14.939003,-17.939003,-23.939003,-24.939003,-23.939003,-22.939003,-23.939003,-27.939003,-29.939003,-30.939003,-29.939003,-28.939003,-27.939003,-25.939003,-22.939003,-19.939003,-16.939003,-15.939003,-15.939003,-15.939003,-13.939003,-10.939003,3.060997,20.060997,18.060997,15.060997,8.060997,-11.939003,-25.939003,-22.939003,-23.939003,-25.939003,-25.939003,-27.939003,-34.939003,-38.939003,-41.939003,-45.939003,-48.939003,-53.939003,-60.939003,-67.939,-73.939,-72.939,-71.939,-74.939,-74.939,-74.939,-78.939,-81.939,-83.939,-81.939,-79.939,-77.939,-75.939,-73.939,-71.939,-70.939,-67.939,-67.939,-67.93
9,-65.939,-63.939003,-62.939003,-57.939003,-52.939003,-45.939003,-36.939003,-31.939003,-32.939003,-34.939003,-34.939003,-29.939003,-29.939003,-32.939003,-33.939003,-34.939003,-37.939003,-37.939003,-35.939003,-35.939003,-34.939003,-32.939003,-26.939003,-22.939003,-22.939003,-20.939003,-18.939003,-12.939003,-7.939003,-2.939003,0.06099701,2.060997,0.06099701,1.060997,2.060997,7.060997,11.060997,15.060997,19.060997,24.060997,36.060997,-0.939003,-50.939003,-50.939003,-45.939003,-33.939003,-32.939003,-27.939003,-11.939003,-0.939003,3.060997,-30.939003,-31.939003,0.06099701,10.060997,4.060997,-34.939003,-16.939003,18.060997,6.060997,-0.939003,-2.939003,-4.939003,-4.939003,-0.939003,14.060997,31.060997,22.060997,8.060997,-12.939003,-28.939003,-31.939003,-8.939003,9.060997,22.060997,17.060997,5.060997,-14.939003,-21.939003,-16.939003,17.060997,21.060997,13.060997,11.060997,9.060997,7.060997,-23.939003,-44.939003,-41.939003,-46.939003,-54.939003,-52.939003,-44.939003,-30.939003,-39.939003,-53.939003,-70.939,-65.939,-47.939003,5.060997,-2.939003,-70.939,-82.939,-88.939,-89.939,-91.939,-94.939,-93.939,-93.939,-93.939,-94.939,-94.939,-91.939,-88.939,-85.939,-83.939,-82.939,-81.939,-80.939,-80.939,-78.939,-76.939,-75.939,-75.939,-75.939,-74.939,-69.939,-64.939,-61.939003,-56.939003,-51.939003,-47.939003,-43.939003,-40.939003,-39.939003,-38.939003,-38.939003,-36.939003,-34.939003,-39.939003,-43.939003,-46.939003,-39.939003,-34.939003,-34.939003,-35.939003,-34.939003,-39.939003,-42.939003,-42.939003,-43.939003,-43.939003,-44.939003,-49.939003,-57.939003,-56.939003,-57.939003,-62.939003,-63.939003,-66.939,-74.939,-78.939,-81.939,-85.939,-87.939,-86.939,-89.939,-91.939,-91.939,-84.939,-73.939,-51.939003,-40.939003,-40.939003,-38.939003,-37.939003,-38.939003,-45.939003,-53.939003,-33.939003,-15.939003,-1.939003,-17.939003,-31.939003,-28.939003,-28.939003,-30.939003,-35.939003,-38.939003,-40.939003,-41.939003,-41.939003,-40.939003,-44.939003,-50.939003,-49.939003,-48.939003,-49.939003,
-49.939003,-49.939003,-49.939003,-45.939003,-41.939003,-38.939003,-34.939003,-30.939003,-23.939003,-19.939003,-25.939003,-37.939003,-48.939003,-43.939003,-40.939003,-40.939003,-39.939003,-38.939003,-37.939003,-36.939003,-37.939003,-34.939003,-30.939003,-24.939003,-20.939003,-18.939003,-21.939003,-22.939003,-22.939003,-21.939003,-21.939003,-23.939003,-27.939003,-27.939003,-16.939003,-19.939003,-27.939003,-26.939003,-27.939003,-27.939003,-29.939003,-31.939003,-33.939003,-31.939003,-30.939003,-31.939003,-32.939003,-32.939003,-33.939003,-36.939003,-42.939003,-41.939003,-36.939003,-38.939003,-38.939003,-36.939003,-39.939003,-42.939003,-39.939003,-39.939003,-39.939003,-38.939003,-38.939003,-40.939003,-44.939003,-48.939003,-49.939003,-51.939003,-53.939003,-56.939003,-58.939003,-56.939003,-61.939003,-64.939,-60.939003,-61.939003,-62.939003,-63.939003,-63.939003,-64.939,-55.939003,-52.939003,-65.939,-68.939,-68.939,-67.939,-68.939,-72.939,-73.939,-73.939,-71.939,-70.939,-71.939,-74.939,-76.939,-76.939,-77.939,-78.939,-78.939,-78.939,-78.939,-66.939,-38.939003,3.060997,13.060997,0.06099701,-52.939003,-75.939,-83.939,-85.939,-85.939,-85.939,-85.939,-85.939,-87.939,-88.939,-89.939,-89.939,-90.939,-92.939,-90.939,-84.939,-72.939,-77.939,-89.939,-95.939,-94.939,-87.939,-87.939,-90.939,-92.939,-97.939,-101.939,-99.939,-90.939,-74.939,-64.939,-52.939003,-39.939003,-31.939003,-26.939003,-23.939003,-22.939003,-22.939003,-23.939003,-26.939003,-28.939003,-29.939003,-30.939003,-31.939003,-29.939003,-26.939003,-30.939003,-33.939003,-37.939003,-44.939003,-53.939003,-51.939003,-49.939003,-46.939003,-52.939003,-53.939003,-46.939003,-57.939003,-75.939,-80.939,-78.939,-70.939,-71.939,-70.939,-61.939003,-53.939003,-45.939003,-43.939003,-48.939003,-60.939003,-57.939003,-55.939003,-58.939003,-68.939,-81.939,-74.939,-67.939,-59.939003,-75.939,-83.939,-65.939,-60.939003,-62.939003,-72.939,-74.939,-68.939,-72.939,-76.939,48.060997,51.060997,55.060997,60.060997,63.060997,66.061,67.061,67.061,70.061,
70.061,70.061,70.061,72.061,78.061,29.060997,-19.939003,-52.939003,-63.939003,-63.939003,-76.939,-79.939,-73.939,-70.939,-71.939,-74.939,-73.939,-71.939,-71.939,-69.939,-65.939,-68.939,-73.939,-78.939,-80.939,-81.939,-86.939,-86.939,-81.939,-75.939,-71.939,-73.939,-71.939,-64.939,-37.939003,-25.939003,-26.939003,-34.939003,-36.939003,-25.939003,-20.939003,-16.939003,-7.939003,-2.939003,0.06099701,5.060997,8.060997,8.060997,6.060997,3.060997,5.060997,5.060997,4.060997,1.060997,-2.939003,-6.939003,-10.939003,-15.939003,-25.939003,-31.939003,-36.939003,-44.939003,-51.939003,-55.939003,-65.939,-74.939,-78.939,-82.939,-83.939,-90.939,-84.939,-45.939003,25.060997,100.061,78.061,81.061,105.061,19.060997,-58.939003,-90.939,-100.939,-100.939,-100.939,-100.939,-101.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-102.939,-102.939,-101.939,-100.939,-98.939,-93.939,-87.939,-80.939,-78.939,-74.939,-66.939,-59.939003,-52.939003,-49.939003,-45.939003,-40.939003,-32.939003,-23.939003,-16.939003,-9.939003,-3.939003,1.060997,4.060997,6.060997,8.060997,11.060997,14.060997,15.060997,14.060997,9.060997,4.060997,0.06099701,0.06099701,-2.939003,-14.939003,-21.939003,-26.939003,-31.939003,-36.939003,-41.939003,-45.939003,-54.939003,-72.939,-67.939,-45.939003,35.060997,73.061,71.061,85.061,69.061,-21.939003,-69.939,-97.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-101.939,-101.939,-102.939,-102.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-100.939,-97.939,-95.939,-89.939,-80.939,-76.939,-70.939,-62.939003,-56.939003,-50.939003,-46.939003,-38.939003,-31.939003,-29.939003,-23.939003,-15.939003,-8.939003,-2.939003,4.060997,7.060997,9.060997,12.060997,13.060997,14.060997,15.060997,14.060997,14.060997,8.060997,2.060997,-0.939003,-4.939003,-6.939003,-16.939003,-22.939003,-20.939003,-25.939003,-32.939003,-13.939003,9.060997,37.060997,39.060997,41.060997,46.060997,50.060997,54.060997,60.060997,62.060997,60.060997,71.061,67.0
61,31.060997,19.060997,18.060997,23.060997,37.060997,64.061,85.061,94.061,79.061,70.061,65.061,75.061,78.061,76.061,75.061,74.061,71.061,71.061,73.061,68.061,62.060997,53.060997,54.060997,45.060997,10.060997,12.060997,27.060997,22.060997,15.060997,7.060997,2.060997,-1.939003,-4.939003,9.060997,25.060997,28.060997,29.060997,28.060997,-11.939003,-39.939003,-32.939003,-13.939003,5.060997,-7.939003,-11.939003,-4.939003,-1.939003,-3.939003,-10.939003,-13.939003,-17.939003,-22.939003,-27.939003,-31.939003,-39.939003,-46.939003,-48.939003,-52.939003,-56.939003,-61.939003,-63.939003,-60.939003,-62.939003,-64.939,-66.939,-64.939,-61.939003,-52.939003,-42.939003,-34.939003,-29.939003,-23.939003,-16.939003,-7.939003,0.06099701,3.060997,4.060997,0.06099701,12.060997,26.060997,43.060997,39.060997,25.060997,0.06099701,-26.939003,-54.939003,21.060997,79.061,68.061,62.060997,56.060997,45.060997,37.060997,32.060997,28.060997,26.060997,27.060997,21.060997,11.060997,3.060997,-4.939003,-10.939003,-15.939003,-20.939003,-24.939003,-26.939003,-30.939003,-36.939003,-40.939003,-41.939003,-38.939003,-36.939003,-35.939003,-35.939003,-35.939003,-31.939003,-26.939003,-19.939003,-12.939003,-8.939003,-9.939003,-6.939003,-1.939003,2.060997,3.060997,1.060997,4.060997,9.060997,13.060997,3.060997,-10.939003,-4.939003,-1.939003,-0.939003,-4.939003,-10.939003,-16.939003,-23.939003,-30.939003,-28.939003,-30.939003,-34.939003,-49.939003,-55.939003,-43.939003,21.060997,97.061,81.061,60.060997,33.060997,-40.939003,-96.939,-96.939,-96.939,-96.939,-96.939,-96.939,-97.939,-97.939,-97.939,-98.939,-97.939,-96.939,-98.939,-98.939,-94.939,-88.939,-81.939,-73.939,-65.939,-59.939003,-56.939003,-52.939003,-44.939003,-37.939003,-32.939003,-26.939003,-19.939003,-12.939003,-8.939003,-3.939003,4.060997,4.060997,4.060997,6.060997,7.060997,6.060997,9.060997,11.060997,9.060997,-5.939003,-18.939003,-23.939003,-23.939003,-22.939003,-17.939003,-16.939003,-19.939003,-12.939003,-7.939003,-6.939003,-1.939003,6.060997,13.060997,1
6.060997,14.060997,28.060997,32.060997,6.060997,2.060997,8.060997,26.060997,46.060997,66.061,65.061,58.060997,51.060997,48.060997,49.060997,45.060997,42.060997,38.060997,44.060997,49.060997,47.060997,-3.939003,-65.939,-49.939003,-39.939003,-34.939003,-48.939003,-49.939003,-13.939003,-1.939003,-5.939003,-42.939003,-33.939003,21.060997,12.060997,-4.939003,-21.939003,-9.939003,14.060997,21.060997,14.060997,-7.939003,-2.939003,0.06099701,-7.939003,8.060997,31.060997,30.060997,17.060997,-8.939003,-27.939003,-32.939003,-8.939003,4.060997,12.060997,19.060997,20.060997,18.060997,3.060997,-6.939003,-2.939003,6.060997,16.060997,24.060997,22.060997,10.060997,-24.939003,-47.939003,-34.939003,-41.939003,-54.939003,-38.939003,-35.939003,-45.939003,-52.939003,-58.939003,-65.939,-67.939,-58.939003,10.060997,0.06099701,-89.939,-99.939,-101.939,-100.939,-96.939,-92.939,-88.939,-83.939,-78.939,-78.939,-75.939,-66.939,-58.939003,-48.939003,-42.939003,-39.939003,-38.939003,-35.939003,-33.939003,-28.939003,-23.939003,-20.939003,-19.939003,-18.939003,-18.939003,-16.939003,-14.939003,-8.939003,-6.939003,-6.939003,-6.939003,-9.939003,-15.939003,-18.939003,-23.939003,-34.939003,-40.939003,-39.939003,-21.939003,-7.939003,2.060997,-21.939003,-49.939003,-79.939,-88.939,-89.939,-94.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.939,-94.939,-88.939,-80.939,-80.939,-78.939,-68.939,-61.939003,-56.939003,-44.939003,-36.939003,-35.939003,-37.939003,-38.939003,-35.939003,-33.939003,-32.939003,-36.939003,-47.939003,-66.939,-52.939003,-43.939003,-51.939003,-55.939003,-57.939003,-56.939003,-53.939003,-48.939003,-43.939003,-39.939003,-37.939003,-28.939003,-16.939003,-19.939003,-15.939003,-6.939003,-4.939003,-4.939003,-4.939003,5.060997,18.060997,17.060997,19.060997,22.060997,36.060997,33.060997,-10.939003,-46.939003,-73.939,-66.939,-58.939003,-51.939003,-51.939003,-51.939003,-49.939003,-41.939003,-34.939003,-35.939003,-31.939003,-23
.939003,-20.939003,-19.939003,-23.939003,-25.939003,-25.939003,-20.939003,-20.939003,-27.939003,-30.939003,-16.939003,35.060997,45.060997,37.060997,38.060997,38.060997,40.060997,29.060997,21.060997,25.060997,31.060997,36.060997,35.060997,32.060997,24.060997,19.060997,6.060997,-25.939003,-15.939003,10.060997,15.060997,13.060997,7.060997,5.060997,4.060997,4.060997,-3.939003,-10.939003,-7.939003,-21.939003,-53.939003,-61.939003,-65.939,-63.939003,-62.939003,-62.939003,-61.939003,-61.939003,-60.939003,-58.939003,-57.939003,-55.939003,-53.939003,-50.939003,-47.939003,-45.939003,-45.939003,-43.939003,-43.939003,-43.939003,-40.939003,-35.939003,-32.939003,-35.939003,-43.939003,-39.939003,-35.939003,-31.939003,-31.939003,-32.939003,-35.939003,-36.939003,-35.939003,-34.939003,-34.939003,-37.939003,-35.939003,-33.939003,-41.939003,-49.939003,-55.939003,-51.939003,-46.939003,-43.939003,-42.939003,-43.939003,-48.939003,-50.939003,-50.939003,-49.939003,-50.939003,-54.939003,-58.939003,-60.939003,-61.939003,-67.939,-77.939,-76.939,-72.939,-66.939,-74.939,-90.939,-98.939,-98.939,-91.939,-86.939,-85.939,-89.939,-96.939,-102.939,-92.939,-73.939,-45.939003,-53.939003,-55.939003,-34.939003,-24.939003,-20.939003,-16.939003,-15.939003,-17.939003,-20.939003,-25.939003,-28.939003,-28.939003,-27.939003,-30.939003,-28.939003,-23.939003,-23.939003,-26.939003,-35.939003,-41.939003,-47.939003,-46.939003,-51.939003,-63.939003,-57.939003,-49.939003,-44.939003,-52.939003,-65.939,-70.939,-72.939,-70.939,-74.939,-75.939,-68.939,-66.939,-64.939,-62.939003,-56.939003,-47.939003,-50.939003,-56.939003,-70.939,-75.939,-76.939,-69.939,-68.939,-72.939,-72.939,-69.939,-60.939003,-61.939003,-68.939,-77.939,-77.939,-70.939,-77.939,-82.939,67.061,67.061,67.061,67.061,69.061,72.061,72.061,72.061,70.061,71.061,73.061,71.061,74.061,81.061,40.060997,-4.939003,-50.939003,-65.939,-66.939,-72.939,-73.939,-67.939,-64.939,-63.939003,-65.939,-62.939003,-57.939003,-54.939003,-50.939003,-44.939003,-43.939003,-43.939003,-
42.939003,-42.939003,-41.939003,-39.939003,-35.939003,-29.939003,-26.939003,-22.939003,-19.939003,-19.939003,-21.939003,-19.939003,-17.939003,-15.939003,-20.939003,-24.939003,-23.939003,-22.939003,-22.939003,-22.939003,-24.939003,-28.939003,-30.939003,-32.939003,-37.939003,-42.939003,-46.939003,-47.939003,-49.939003,-52.939003,-55.939003,-57.939003,-60.939003,-62.939003,-64.939,-69.939,-72.939,-75.939,-78.939,-82.939,-84.939,-88.939,-93.939,-95.939,-96.939,-96.939,-100.939,-93.939,-64.939,15.060997,104.061,81.061,81.061,101.061,24.060997,-49.939003,-91.939,-97.939,-91.939,-90.939,-88.939,-86.939,-83.939,-80.939,-78.939,-73.939,-68.939,-63.939003,-59.939003,-55.939003,-52.939003,-49.939003,-44.939003,-41.939003,-39.939003,-33.939003,-29.939003,-25.939003,-24.939003,-23.939003,-18.939003,-17.939003,-16.939003,-19.939003,-20.939003,-21.939003,-19.939003,-17.939003,-18.939003,-20.939003,-22.939003,-22.939003,-25.939003,-29.939003,-34.939003,-38.939003,-40.939003,-40.939003,-41.939003,-48.939003,-53.939003,-56.939003,-56.939003,-57.939003,-62.939003,-66.939,-69.939,-71.939,-74.939,-77.939,-79.939,-83.939,-92.939,-87.939,-67.939,42.060997,91.061,80.061,98.061,86.061,-2.939003,-61.939003,-103.939,-101.939,-101.939,-101.939,-98.939,-96.939,-94.939,-91.939,-87.939,-81.939,-77.939,-75.939,-71.939,-68.939,-63.939003,-58.939003,-53.939003,-50.939003,-45.939003,-40.939003,-36.939003,-33.939003,-29.939003,-26.939003,-22.939003,-23.939003,-21.939003,-16.939003,-15.939003,-14.939003,-12.939003,-14.939003,-17.939003,-17.939003,-16.939003,-15.939003,-18.939003,-20.939003,-20.939003,-22.939003,-25.939003,-26.939003,-31.939003,-39.939003,-41.939003,-43.939003,-42.939003,-47.939003,-52.939003,-56.939003,-59.939003,-60.939003,-65.939,-59.939003,-29.939003,-27.939003,-34.939003,-8.939003,26.060997,71.061,69.061,64.061,66.061,67.061,68.061,71.061,69.061,63.060997,74.061,72.061,39.060997,15.060997,0.06099701,15.060997,30.060997,45.060997,59.060997,66.061,53.060997,45.060997,39.060997,37.060
997,34.060997,31.060997,29.060997,26.060997,18.060997,14.060997,12.060997,7.060997,4.060997,0.06099701,-1.939003,-5.939003,-12.939003,-14.939003,-14.939003,-15.939003,-16.939003,-16.939003,-20.939003,-20.939003,-5.939003,13.060997,31.060997,32.060997,30.060997,24.060997,-10.939003,-41.939003,-57.939003,-47.939003,-30.939003,-36.939003,-39.939003,-36.939003,-35.939003,-34.939003,-35.939003,-32.939003,-27.939003,-27.939003,-26.939003,-25.939003,-20.939003,-15.939003,-9.939003,-7.939003,-6.939003,-3.939003,2.060997,11.060997,12.060997,13.060997,15.060997,16.060997,16.060997,18.060997,20.060997,21.060997,22.060997,24.060997,27.060997,27.060997,26.060997,21.060997,11.060997,-3.939003,11.060997,28.060997,40.060997,36.060997,24.060997,1.060997,-26.939003,-59.939003,-13.939003,21.060997,12.060997,9.060997,9.060997,0.06099701,-7.939003,-13.939003,-16.939003,-17.939003,-11.939003,-7.939003,-4.939003,-6.939003,-5.939003,-1.939003,-3.939003,-5.939003,-6.939003,-6.939003,-6.939003,-10.939003,-13.939003,-15.939003,-16.939003,-18.939003,-20.939003,-21.939003,-22.939003,-19.939003,-19.939003,-23.939003,-24.939003,-27.939003,-30.939003,-33.939003,-35.939003,-41.939003,-42.939003,-38.939003,-34.939003,-35.939003,-46.939003,-49.939003,-49.939003,-54.939003,-57.939003,-57.939003,-59.939003,-61.939003,-64.939,-68.939,-71.939,-71.939,-71.939,-74.939,-81.939,-79.939,-58.939003,2.060997,68.061,47.060997,25.060997,2.060997,-45.939003,-78.939,-66.939,-61.939003,-58.939003,-55.939003,-52.939003,-50.939003,-46.939003,-43.939003,-41.939003,-40.939003,-38.939003,-37.939003,-35.939003,-30.939003,-28.939003,-26.939003,-25.939003,-21.939003,-19.939003,-20.939003,-19.939003,-15.939003,-16.939003,-17.939003,-18.939003,-16.939003,-14.939003,-18.939003,-21.939003,-22.939003,-25.939003,-28.939003,-32.939003,-37.939003,-43.939003,-45.939003,-47.939003,-48.939003,-33.939003,-21.939003,-24.939003,-34.939003,-45.939003,-22.939003,3.060997,33.060997,33.060997,31.060997,30.060997,27.060997,23.060997,29.060997
,30.060997,25.060997,34.060997,33.060997,7.060997,-3.939003,-3.939003,13.060997,27.060997,38.060997,32.060997,24.060997,16.060997,11.060997,9.060997,5.060997,1.060997,-3.939003,-3.939003,-3.939003,-5.939003,-30.939003,-60.939003,-39.939003,-32.939003,-38.939003,-50.939003,-46.939003,-3.939003,4.060997,-5.939003,-31.939003,-22.939003,21.060997,4.060997,-13.939003,-18.939003,-19.939003,-16.939003,1.060997,10.060997,9.060997,6.060997,3.060997,4.060997,18.060997,34.060997,24.060997,10.060997,-7.939003,-6.939003,-2.939003,2.060997,14.060997,28.060997,25.060997,21.060997,19.060997,15.060997,13.060997,16.060997,11.060997,5.060997,8.060997,1.060997,-16.939003,-24.939003,-26.939003,-16.939003,-26.939003,-41.939003,-24.939003,-24.939003,-42.939003,-48.939003,-53.939003,-61.939003,-70.939,-74.939,-27.939003,-20.939003,-53.939003,-53.939003,-51.939003,-48.939003,-43.939003,-39.939003,-39.939003,-39.939003,-36.939003,-37.939003,-36.939003,-31.939003,-30.939003,-28.939003,-30.939003,-31.939003,-33.939003,-33.939003,-32.939003,-33.939003,-35.939003,-38.939003,-42.939003,-45.939003,-49.939003,-53.939003,-56.939003,-54.939003,-56.939003,-57.939003,-58.939003,-61.939003,-63.939003,-65.939,-68.939,-73.939,-69.939,-57.939003,4.060997,30.060997,23.060997,-18.939003,-59.939003,-88.939,-96.939,-95.939,-93.939,-89.939,-87.939,-85.939,-83.939,-76.939,-73.939,-71.939,-69.939,-66.939,-61.939003,-62.939003,-62.939003,-57.939003,-54.939003,-53.939003,-50.939003,-47.939003,-45.939003,-45.939003,-43.939003,-35.939003,-36.939003,-42.939003,-40.939003,-38.939003,-37.939003,-39.939003,-40.939003,-39.939003,-42.939003,-43.939003,-34.939003,-29.939003,-31.939003,-33.939003,-32.939003,-24.939003,-21.939003,-21.939003,-15.939003,-14.939003,-18.939003,-25.939003,-26.939003,-11.939003,2.060997,13.060997,7.060997,5.060997,7.060997,3.060997,-1.939003,-3.939003,-0.939003,4.060997,2.060997,-1.939003,-9.939003,-3.939003,-4.939003,-25.939003,-42.939003,-55.939003,-49.939003,-43.939003,-38.939003,-35.939003,-34.
939003,-38.939003,-37.939003,-37.939003,-37.939003,-31.939003,-20.939003,-19.939003,-20.939003,-22.939003,-22.939003,-20.939003,-18.939003,-19.939003,-24.939003,-27.939003,-19.939003,10.060997,13.060997,4.060997,8.060997,12.060997,16.060997,6.060997,0.06099701,9.060997,14.060997,17.060997,19.060997,19.060997,17.060997,15.060997,4.060997,-26.939003,-11.939003,20.060997,25.060997,25.060997,22.060997,21.060997,21.060997,22.060997,21.060997,19.060997,19.060997,-12.939003,-76.939,-85.939,-87.939,-86.939,-85.939,-85.939,-85.939,-85.939,-84.939,-83.939,-82.939,-82.939,-81.939,-79.939,-78.939,-73.939,-65.939,-62.939003,-63.939003,-73.939,-74.939,-72.939,-71.939,-71.939,-72.939,-71.939,-70.939,-69.939,-69.939,-70.939,-69.939,-69.939,-69.939,-68.939,-67.939,-66.939,-66.939,-66.939,-50.939003,-32.939003,-14.939003,-16.939003,-26.939003,-47.939003,-55.939003,-58.939003,-57.939003,-57.939003,-57.939003,-53.939003,-53.939003,-58.939003,-59.939003,-58.939003,-60.939003,-66.939,-75.939,-79.939,-80.939,-75.939,-81.939,-92.939,-93.939,-93.939,-93.939,-90.939,-86.939,-82.939,-90.939,-101.939,-84.939,-60.939003,-31.939003,-38.939003,-41.939003,-31.939003,-25.939003,-23.939003,-23.939003,-21.939003,-18.939003,-21.939003,-23.939003,-21.939003,-21.939003,-22.939003,-24.939003,-26.939003,-26.939003,-24.939003,-24.939003,-29.939003,-37.939003,-44.939003,-37.939003,-37.939003,-48.939003,-56.939003,-59.939003,-50.939003,-49.939003,-50.939003,-61.939003,-65.939,-62.939003,-60.939003,-59.939003,-55.939003,-55.939003,-56.939003,-63.939003,-65.939,-61.939003,-64.939,-68.939,-75.939,-77.939,-75.939,-72.939,-75.939,-84.939,-80.939,-75.939,-68.939,-69.939,-75.939,-75.939,-72.939,-66.939,-73.939,-78.939,74.061,72.061,69.061,66.061,64.061,64.061,64.061,61.060997,55.060997,54.060997,55.060997,52.060997,53.060997,58.060997,31.060997,-2.939003,-43.939003,-55.939003,-56.939003,-56.939003,-55.939003,-51.939003,-48.939003,-46.939003,-48.939003,-45.939003,-39.939003,-36.939003,-32.939003,-26.939003,-24.93900
3,-21.939003,-17.939003,-17.939003,-17.939003,-11.939003,-7.939003,-3.939003,-4.939003,-2.939003,3.060997,-1.939003,-10.939003,-13.939003,-14.939003,-14.939003,-23.939003,-33.939003,-40.939003,-43.939003,-45.939003,-50.939003,-57.939003,-65.939,-70.939,-76.939,-82.939,-87.939,-91.939,-94.939,-96.939,-100.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-101.939,-101.939,-99.939,-97.939,-97.939,-96.939,-95.939,-95.939,-93.939,-86.939,-68.939,-0.939003,79.061,60.060997,57.060997,68.061,14.060997,-38.939003,-72.939,-73.939,-62.939003,-61.939003,-59.939003,-54.939003,-52.939003,-49.939003,-46.939003,-40.939003,-33.939003,-27.939003,-22.939003,-20.939003,-17.939003,-13.939003,-7.939003,-5.939003,-6.939003,-3.939003,-1.939003,-1.939003,-3.939003,-6.939003,-5.939003,-8.939003,-11.939003,-18.939003,-23.939003,-27.939003,-29.939003,-33.939003,-40.939003,-47.939003,-53.939003,-57.939003,-62.939003,-70.939,-79.939,-86.939,-90.939,-91.939,-92.939,-98.939,-101.939,-103.939,-102.939,-101.939,-100.939,-99.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-92.939,-76.939,34.060997,81.061,65.061,79.061,72.061,4.060997,-45.939003,-84.939,-80.939,-79.939,-79.939,-74.939,-69.939,-67.939,-63.939003,-57.939003,-49.939003,-43.939003,-40.939003,-36.939003,-32.939003,-27.939003,-21.939003,-15.939003,-11.939003,-6.939003,-1.939003,-0.939003,1.060997,4.060997,3.060997,1.060997,-3.939003,-6.939003,-3.939003,-6.939003,-9.939003,-10.939003,-18.939003,-27.939003,-29.939003,-31.939003,-35.939003,-44.939003,-52.939003,-55.939003,-60.939003,-66.939,-70.939,-78.939,-88.939,-91.939,-93.939,-91.939,-94.939,-97.939,-100.939,-102.939,-101.939,-101.939,-86.939,-38.939003,-27.939003,-31.939003,-6.939003,29.060997,78.061,75.061,67.061,63.060997,60.060997,59.060997,57.060997,52.060997,44.060997,52.060997,50.060997,28.060997,4.060997,-15.939003,1.060997,12.060997,17.060997,22.060997,24.060997,18.060997,14.060997,9.060997,2.060997,-1.939003,-4.939003,-5.939003,-8.939003,-
16.939003,-21.939003,-24.939003,-26.939003,-27.939003,-26.939003,-27.939003,-27.939003,-22.939003,-22.939003,-25.939003,-23.939003,-19.939003,-13.939003,-15.939003,-14.939003,-1.939003,16.060997,33.060997,34.060997,30.060997,22.060997,-9.939003,-41.939003,-69.939,-60.939003,-37.939003,-38.939003,-39.939003,-39.939003,-37.939003,-36.939003,-32.939003,-25.939003,-16.939003,-11.939003,-8.939003,-4.939003,6.060997,17.060997,25.060997,30.060997,32.060997,39.060997,47.060997,57.060997,58.060997,59.060997,61.060997,61.060997,59.060997,54.060997,50.060997,47.060997,45.060997,43.060997,41.060997,36.060997,29.060997,19.060997,6.060997,-10.939003,10.060997,30.060997,38.060997,34.060997,25.060997,4.060997,-25.939003,-63.939003,-40.939003,-21.939003,-29.939003,-28.939003,-25.939003,-28.939003,-30.939003,-31.939003,-32.939003,-32.939003,-31.939003,-26.939003,-21.939003,-20.939003,-14.939003,-5.939003,-5.939003,-7.939003,-8.939003,-7.939003,-4.939003,-6.939003,-8.939003,-10.939003,-14.939003,-18.939003,-22.939003,-24.939003,-26.939003,-23.939003,-28.939003,-39.939003,-46.939003,-53.939003,-57.939003,-63.939003,-70.939,-81.939,-83.939,-73.939,-67.939,-70.939,-91.939,-86.939,-73.939,-85.939,-90.939,-89.939,-88.939,-87.939,-86.939,-84.939,-82.939,-82.939,-82.939,-82.939,-81.939,-75.939,-56.939003,-17.939003,23.060997,3.060997,-12.939003,-23.939003,-40.939003,-48.939003,-31.939003,-23.939003,-20.939003,-17.939003,-14.939003,-11.939003,-7.939003,-4.939003,-2.939003,-3.939003,-3.939003,-1.939003,0.06099701,2.060997,0.06099701,-2.939003,-7.939003,-9.939003,-10.939003,-14.939003,-15.939003,-16.939003,-22.939003,-27.939003,-32.939003,-34.939003,-36.939003,-44.939003,-52.939003,-59.939003,-63.939003,-68.939,-75.939,-82.939,-90.939,-96.939,-97.939,-97.939,-57.939003,-26.939003,-26.939003,-42.939003,-60.939003,-29.939003,9.060997,56.060997,51.060997,42.060997,39.060997,30.060997,19.060997,23.060997,22.060997,16.060997,18.060997,14.060997,-2.939003,-13.939003,-19.939003,-7.939003,-1.939003,-2.
939003,-9.939003,-15.939003,-19.939003,-23.939003,-26.939003,-28.939003,-31.939003,-34.939003,-36.939003,-37.939003,-39.939003,-46.939003,-53.939003,-31.939003,-28.939003,-42.939003,-46.939003,-36.939003,3.060997,8.060997,-3.939003,-16.939003,-8.939003,17.060997,0.06099701,-13.939003,-8.939003,-17.939003,-29.939003,-13.939003,5.060997,27.060997,12.060997,2.060997,12.060997,25.060997,36.060997,21.060997,11.060997,3.060997,17.060997,26.060997,16.060997,23.060997,35.060997,23.060997,14.060997,9.060997,14.060997,20.060997,26.060997,13.060997,-5.939003,-5.939003,-12.939003,-27.939003,-18.939003,-8.939003,-4.939003,-16.939003,-31.939003,-16.939003,-19.939003,-38.939003,-45.939003,-51.939003,-60.939003,-73.939,-84.939,-55.939003,-35.939003,-25.939003,-21.939003,-19.939003,-17.939003,-15.939003,-13.939003,-19.939003,-23.939003,-23.939003,-24.939003,-25.939003,-24.939003,-27.939003,-32.939003,-38.939003,-42.939003,-45.939003,-47.939003,-49.939003,-53.939003,-59.939003,-66.939,-71.939,-77.939,-82.939,-88.939,-93.939,-92.939,-94.939,-95.939,-95.939,-95.939,-93.939,-93.939,-92.939,-91.939,-82.939,-64.939,8.060997,35.060997,16.060997,-23.939003,-59.939003,-77.939,-82.939,-79.939,-72.939,-66.939,-62.939003,-60.939003,-58.939003,-48.939003,-45.939003,-44.939003,-42.939003,-39.939003,-32.939003,-36.939003,-38.939003,-32.939003,-29.939003,-28.939003,-28.939003,-30.939003,-34.939003,-34.939003,-33.939003,-28.939003,-34.939003,-42.939003,-42.939003,-42.939003,-43.939003,-41.939003,-40.939003,-43.939003,-51.939003,-58.939003,-32.939003,-11.939003,1.060997,-17.939003,-25.939003,-7.939003,-0.939003,1.060997,6.060997,5.060997,-3.939003,-18.939003,-23.939003,-3.939003,7.060997,13.060997,7.060997,2.060997,-2.939003,-8.939003,-13.939003,-16.939003,-19.939003,-20.939003,-22.939003,-29.939003,-40.939003,-40.939003,-39.939003,-37.939003,-37.939003,-38.939003,-35.939003,-32.939003,-29.939003,-24.939003,-23.939003,-30.939003,-35.939003,-39.939003,-39.939003,-32.939003,-19.939003,-19.939003,-21.93
9003,-22.939003,-20.939003,-16.939003,-17.939003,-19.939003,-22.939003,-25.939003,-24.939003,-17.939003,-22.939003,-31.939003,-24.939003,-19.939003,-16.939003,-23.939003,-25.939003,-16.939003,-13.939003,-12.939003,-9.939003,-7.939003,-5.939003,-4.939003,-11.939003,-32.939003,-18.939003,7.060997,10.060997,12.060997,13.060997,12.060997,12.060997,14.060997,18.060997,22.060997,20.060997,-14.939003,-84.939,-89.939,-88.939,-90.939,-91.939,-91.939,-91.939,-92.939,-92.939,-93.939,-94.939,-93.939,-93.939,-94.939,-95.939,-90.939,-77.939,-73.939,-75.939,-93.939,-97.939,-97.939,-97.939,-96.939,-94.939,-95.939,-97.939,-98.939,-99.939,-99.939,-96.939,-95.939,-97.939,-98.939,-97.939,-92.939,-95.939,-95.939,-57.939003,-16.939003,26.060997,19.060997,-4.939003,-55.939003,-74.939,-78.939,-74.939,-72.939,-72.939,-67.939,-66.939,-71.939,-70.939,-66.939,-69.939,-74.939,-79.939,-84.939,-86.939,-83.939,-86.939,-92.939,-91.939,-91.939,-95.939,-94.939,-89.939,-77.939,-87.939,-101.939,-78.939,-52.939003,-23.939003,-29.939003,-35.939003,-35.939003,-34.939003,-32.939003,-33.939003,-31.939003,-26.939003,-25.939003,-24.939003,-19.939003,-18.939003,-19.939003,-19.939003,-22.939003,-26.939003,-27.939003,-27.939003,-26.939003,-33.939003,-40.939003,-32.939003,-32.939003,-39.939003,-58.939003,-68.939,-58.939003,-49.939003,-41.939003,-55.939003,-59.939003,-54.939003,-49.939003,-45.939003,-42.939003,-43.939003,-46.939003,-59.939003,-66.939,-69.939,-75.939,-79.939,-80.939,-78.939,-75.939,-77.939,-81.939,-88.939,-83.939,-78.939,-76.939,-76.939,-77.939,-73.939,-68.939,-64.939,-70.939,-74.939,63.060997,64.061,65.061,61.060997,53.060997,42.060997,39.060997,33.060997,20.060997,12.060997,4.060997,0.06099701,-5.939003,-12.939003,-22.939003,-29.939003,-30.939003,-24.939003,-15.939003,-15.939003,-14.939003,-13.939003,-12.939003,-11.939003,-13.939003,-14.939003,-13.939003,-16.939003,-17.939003,-16.939003,-22.939003,-26.939003,-25.939003,-31.939003,-39.939003,-40.939003,-42.939003,-48.939003,-54.939003,-57.939003,-
58.939003,-68.939,-78.939,-39.939003,-24.939003,-34.939003,-64.939,-87.939,-88.939,-89.939,-90.939,-91.939,-92.939,-93.939,-94.939,-95.939,-97.939,-98.939,-100.939,-101.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-100.939,-96.939,-94.939,-92.939,-89.939,-82.939,-76.939,-74.939,-74.939,-74.939,-65.939,-54.939003,-41.939003,-19.939003,4.060997,-2.939003,-9.939003,-17.939003,-24.939003,-25.939003,-13.939003,-4.939003,2.060997,1.060997,2.060997,3.060997,1.060997,-2.939003,-6.939003,-6.939003,-6.939003,-8.939003,-14.939003,-21.939003,-26.939003,-31.939003,-34.939003,-40.939003,-47.939003,-52.939003,-56.939003,-59.939003,-68.939,-76.939,-80.939,-81.939,-82.939,-83.939,-84.939,-85.939,-86.939,-87.939,-88.939,-90.939,-92.939,-92.939,-94.939,-95.939,-97.939,-99.939,-100.939,-100.939,-101.939,-102.939,-103.939,-103.939,-103.939,-102.939,-98.939,-94.939,-92.939,-96.939,-94.939,-87.939,-86.939,-83.939,-77.939,-70.939,-58.939003,-3.939003,19.060997,10.060997,5.060997,-2.939003,-20.939003,-24.939003,-21.939003,-15.939003,-12.939003,-13.939003,-4.939003,1.060997,0.06099701,1.060997,1.060997,2.060997,1.060997,1.060997,-0.939003,-4.939003,-9.939003,-14.939003,-18.939003,-19.939003,-25.939003,-36.939003,-42.939003,-47.939003,-50.939003,-57.939003,-65.939,-72.939,-75.939,-74.939,-78.939,-81.939,-82.939,-83.939,-86.939,-86.939,-86.939,-87.939,-89.939,-90.939,-88.939,-91.939,-94.939,-95.939,-97.939,-100.939,-96.939,-94.939,-93.939,-93.939,-93.939,-98.939,-98.939,-94.939,-95.939,-84.939,-43.939003,-27.939003,-21.939003,-9.939003,8.060997,32.060997,33.060997,28.060997,14.060997,8.060997,4.060997,-0.939003,-4.939003,-8.939003,-12.939003,-17.939003,-22.939003,-20.939003,-16.939003,-18.939003,-18.939003,-17.939003,-27.939003,-32.939003,-22.939003,-20.939003,-19.939003,-11.939003,-6.939003,-5.939003,-1.939003,2.060997,4.060997,7.060997,10.060997,14.060997,17.060997,18.060997,24.060997,23.060997,6.060997,14.060997,33.060997,39.060997,41.060997,41.060997,45.06099
7,42.060997,17.060997,15.060997,24.060997,29.060997,30.060997,29.060997,-8.939003,-37.939003,-37.939003,2.060997,53.060997,43.060997,41.060997,49.060997,50.060997,51.060997,50.060997,52.060997,55.060997,53.060997,50.060997,45.060997,44.060997,42.060997,36.060997,29.060997,22.060997,20.060997,18.060997,15.060997,7.060997,-1.939003,-8.939003,-10.939003,-9.939003,-16.939003,-19.939003,-17.939003,-19.939003,-20.939003,-23.939003,-22.939003,-18.939003,-18.939003,-19.939003,-21.939003,6.060997,30.060997,39.060997,37.060997,31.060997,12.060997,-18.939003,-61.939003,-21.939003,8.060997,-0.939003,-3.939003,-3.939003,-0.939003,6.060997,17.060997,23.060997,22.060997,6.060997,-18.939003,-43.939003,-46.939003,-49.939003,-50.939003,-56.939003,-62.939003,-71.939,-75.939,-77.939,-80.939,-81.939,-82.939,-82.939,-83.939,-84.939,-85.939,-85.939,-85.939,-86.939,-88.939,-90.939,-91.939,-92.939,-94.939,-95.939,-96.939,-91.939,-81.939,-67.939,-61.939003,-73.939,-65.939,-48.939003,-49.939003,-46.939003,-40.939003,-37.939003,-33.939003,-28.939003,-20.939003,-12.939003,-13.939003,-12.939003,-11.939003,-7.939003,-8.939003,-19.939003,-28.939003,-35.939003,-40.939003,-37.939003,-28.939003,-18.939003,-9.939003,-6.939003,-7.939003,-8.939003,-15.939003,-19.939003,-21.939003,-26.939003,-32.939003,-37.939003,-41.939003,-44.939003,-50.939003,-54.939003,-57.939003,-62.939003,-68.939,-76.939,-80.939,-82.939,-82.939,-83.939,-83.939,-84.939,-86.939,-87.939,-87.939,-87.939,-89.939,-91.939,-93.939,-94.939,-95.939,-96.939,-98.939,-100.939,-101.939,-99.939,-92.939,-57.939003,-30.939003,-27.939003,-36.939003,-47.939003,-36.939003,-20.939003,1.060997,-5.939003,-13.939003,-20.939003,-24.939003,-27.939003,-27.939003,-28.939003,-31.939003,-36.939003,-38.939003,-27.939003,-28.939003,-34.939003,-33.939003,-37.939003,-46.939003,-44.939003,-40.939003,-35.939003,-31.939003,-30.939003,-30.939003,-28.939003,-25.939003,-17.939003,-12.939003,-14.939003,-30.939003,-48.939003,-27.939003,-27.939003,-47.939003,-31.939003,-13.
939003,5.060997,7.060997,4.060997,5.060997,6.060997,10.060997,8.060997,10.060997,20.060997,19.060997,12.060997,-5.939003,3.060997,39.060997,11.060997,-9.939003,4.060997,19.060997,31.060997,30.060997,32.060997,36.060997,39.060997,41.060997,36.060997,24.060997,10.060997,4.060997,-4.939003,-17.939003,-11.939003,-2.939003,8.060997,-1.939003,-17.939003,-8.939003,-1.939003,5.060997,2.060997,-1.939003,-6.939003,-17.939003,-29.939003,-24.939003,-25.939003,-33.939003,-46.939003,-58.939003,-66.939,-74.939,-79.939,-47.939003,-32.939003,-34.939003,-38.939003,-44.939003,-56.939003,-61.939003,-64.939,-77.939,-84.939,-85.939,-85.939,-85.939,-85.939,-86.939,-87.939,-88.939,-89.939,-90.939,-90.939,-90.939,-91.939,-93.939,-94.939,-96.939,-97.939,-98.939,-95.939,-91.939,-87.939,-83.939,-77.939,-73.939,-68.939,-62.939003,-59.939003,-56.939003,-51.939003,-47.939003,-45.939003,-36.939003,-36.939003,-44.939003,-40.939003,-34.939003,-27.939003,-24.939003,-22.939003,-22.939003,-21.939003,-21.939003,-25.939003,-27.939003,-27.939003,-31.939003,-38.939003,-40.939003,-41.939003,-41.939003,-52.939003,-62.939003,-64.939,-67.939,-70.939,-73.939,-78.939,-82.939,-86.939,-87.939,-86.939,-82.939,-77.939,-59.939003,-51.939003,-54.939003,-45.939003,-39.939003,-44.939003,-58.939003,-69.939,-30.939003,-6.939003,-0.939003,-22.939003,-37.939003,-28.939003,-25.939003,-27.939003,-28.939003,-31.939003,-35.939003,-41.939003,-45.939003,-43.939003,-47.939003,-54.939003,-54.939003,-56.939003,-59.939003,-57.939003,-55.939003,-51.939003,-51.939003,-53.939003,-51.939003,-50.939003,-50.939003,-48.939003,-44.939003,-35.939003,-35.939003,-37.939003,-39.939003,-38.939003,-35.939003,-35.939003,-35.939003,-36.939003,-38.939003,-41.939003,-41.939003,-36.939003,-26.939003,-23.939003,-22.939003,-23.939003,-20.939003,-14.939003,-15.939003,-17.939003,-20.939003,-24.939003,-28.939003,-33.939003,-40.939003,-48.939003,-40.939003,-38.939003,-43.939003,-44.939003,-45.939003,-44.939003,-44.939003,-46.939003,-46.939003,-46.939003,-46.
939003,-48.939003,-47.939003,-43.939003,-46.939003,-51.939003,-50.939003,-48.939003,-45.939003,-46.939003,-47.939003,-46.939003,-42.939003,-38.939003,-34.939003,-39.939003,-51.939003,-44.939003,-40.939003,-46.939003,-50.939003,-53.939003,-53.939003,-53.939003,-54.939003,-60.939003,-63.939003,-62.939003,-61.939003,-62.939003,-70.939,-69.939,-62.939003,-59.939003,-61.939003,-76.939,-79.939,-77.939,-79.939,-81.939,-83.939,-83.939,-83.939,-85.939,-88.939,-91.939,-88.939,-89.939,-93.939,-96.939,-98.939,-95.939,-96.939,-95.939,-55.939003,-10.939003,40.060997,33.060997,2.060997,-68.939,-92.939,-97.939,-96.939,-94.939,-91.939,-92.939,-95.939,-96.939,-95.939,-95.939,-95.939,-96.939,-98.939,-90.939,-84.939,-84.939,-85.939,-86.939,-93.939,-96.939,-96.939,-95.939,-92.939,-81.939,-88.939,-101.939,-78.939,-52.939003,-23.939003,-35.939003,-50.939003,-57.939003,-57.939003,-54.939003,-47.939003,-46.939003,-49.939003,-40.939003,-33.939003,-31.939003,-26.939003,-20.939003,-17.939003,-16.939003,-17.939003,-28.939003,-35.939003,-28.939003,-28.939003,-30.939003,-40.939003,-47.939003,-54.939003,-68.939,-77.939,-70.939,-58.939003,-45.939003,-55.939003,-57.939003,-52.939003,-46.939003,-41.939003,-36.939003,-35.939003,-38.939003,-44.939003,-51.939003,-60.939003,-74.939,-84.939,-84.939,-81.939,-79.939,-86.939,-84.939,-73.939,-70.939,-71.939,-81.939,-77.939,-69.939,-70.939,-70.939,-66.939,-69.939,-71.939,15.060997,15.060997,14.060997,8.060997,2.060997,-2.939003,-5.939003,-8.939003,-12.939003,-17.939003,-22.939003,-23.939003,-23.939003,-23.939003,-27.939003,-30.939003,-32.939003,-28.939003,-23.939003,-22.939003,-23.939003,-24.939003,-26.939003,-29.939003,-33.939003,-36.939003,-39.939003,-43.939003,-46.939003,-45.939003,-51.939003,-56.939003,-58.939003,-62.939003,-66.939,-69.939,-74.939,-79.939,-83.939,-85.939,-86.939,-93.939,-99.939,-55.939003,-33.939003,-33.939003,-70.939,-100.939,-102.939,-102.939,-100.939,-102.939,-100.939,-95.939,-94.939,-94.939,-90.939,-87.939,-84.939,-85.939,-83.939,-80.9
39,-77.939,-75.939,-71.939,-68.939,-66.939,-63.939003,-59.939003,-53.939003,-49.939003,-45.939003,-45.939003,-39.939003,-33.939003,-29.939003,-29.939003,-32.939003,-26.939003,-20.939003,-16.939003,-15.939003,-13.939003,-11.939003,-11.939003,-13.939003,-15.939003,-18.939003,-22.939003,-19.939003,-16.939003,-19.939003,-22.939003,-25.939003,-31.939003,-36.939003,-41.939003,-43.939003,-44.939003,-47.939003,-51.939003,-57.939003,-61.939003,-66.939,-69.939,-74.939,-79.939,-83.939,-86.939,-88.939,-95.939,-100.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-100.939,-98.939,-96.939,-94.939,-91.939,-87.939,-83.939,-81.939,-79.939,-75.939,-71.939,-69.939,-66.939,-63.939003,-58.939003,-52.939003,-48.939003,-50.939003,-48.939003,-42.939003,-37.939003,-33.939003,-31.939003,-29.939003,-26.939003,-12.939003,-8.939003,-14.939003,-13.939003,-13.939003,-19.939003,-21.939003,-19.939003,-15.939003,-15.939003,-21.939003,-18.939003,-17.939003,-20.939003,-24.939003,-29.939003,-33.939003,-36.939003,-37.939003,-39.939003,-42.939003,-47.939003,-52.939003,-56.939003,-57.939003,-62.939003,-69.939,-75.939,-80.939,-83.939,-88.939,-93.939,-97.939,-99.939,-98.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-102.939,-101.939,-99.939,-96.939,-90.939,-89.939,-90.939,-85.939,-81.939,-77.939,-71.939,-65.939,-59.939003,-56.939003,-53.939003,-53.939003,-51.939003,-46.939003,-45.939003,-40.939003,-24.939003,-18.939003,-16.939003,-11.939003,-8.939003,-6.939003,-3.939003,-3.939003,-10.939003,-12.939003,-11.939003,-10.939003,-10.939003,-11.939003,-10.939003,-10.939003,-15.939003,-16.939003,-14.939003,-13.939003,-9.939003,-0.939003,4.060997,9.060997,14.060997,13.060997,11.060997,22.060997,27.060997,29.060997,33.060997,38.060997,38.060997,42.060997,46.060997,51.060997,52.060997,51.060997,62.060997,59.060997,24.060997,32.060997,58.060997,64.061,65.061,62.060997,66.061,57.060997,17.060997,12.060997,22.060997,32.060997,34.060997,30.060997,-4.939003,-30.93900
3,-27.939003,4.060997,42.060997,30.060997,26.060997,30.060997,30.060997,29.060997,30.060997,27.060997,24.060997,23.060997,20.060997,14.060997,13.060997,13.060997,11.060997,8.060997,4.060997,2.060997,-0.939003,-4.939003,-8.939003,-12.939003,-16.939003,-13.939003,-7.939003,-13.939003,-16.939003,-15.939003,-19.939003,-23.939003,-28.939003,-28.939003,-27.939003,-27.939003,-24.939003,-17.939003,9.060997,32.060997,42.060997,39.060997,31.060997,19.060997,-13.939003,-69.939,-51.939003,-34.939003,-37.939003,-41.939003,-40.939003,-5.939003,27.060997,57.060997,62.060997,53.060997,19.060997,-27.939003,-73.939,-75.939,-78.939,-81.939,-84.939,-88.939,-93.939,-93.939,-91.939,-90.939,-88.939,-85.939,-83.939,-80.939,-76.939,-73.939,-70.939,-67.939,-64.939,-61.939003,-58.939003,-57.939003,-56.939003,-54.939003,-51.939003,-48.939003,-45.939003,-41.939003,-35.939003,-30.939003,-29.939003,-26.939003,-23.939003,-21.939003,-18.939003,-14.939003,-14.939003,-13.939003,-11.939003,-9.939003,-8.939003,-9.939003,-11.939003,-12.939003,-15.939003,-18.939003,-20.939003,-14.939003,-5.939003,0.06099701,2.060997,-1.939003,-25.939003,-44.939003,-44.939003,-46.939003,-48.939003,-53.939003,-56.939003,-59.939003,-62.939003,-68.939,-72.939,-75.939,-77.939,-81.939,-84.939,-87.939,-91.939,-94.939,-96.939,-97.939,-96.939,-96.939,-94.939,-90.939,-87.939,-85.939,-82.939,-79.939,-76.939,-75.939,-72.939,-69.939,-68.939,-67.939,-67.939,-63.939003,-58.939003,-56.939003,-53.939003,-50.939003,-36.939003,-26.939003,-25.939003,-27.939003,-29.939003,-28.939003,-29.939003,-31.939003,-32.939003,-34.939003,-37.939003,-37.939003,-36.939003,-33.939003,-30.939003,-29.939003,-30.939003,-30.939003,-21.939003,-22.939003,-26.939003,-20.939003,-14.939003,-8.939003,-6.939003,-4.939003,-0.939003,3.060997,6.060997,8.060997,9.060997,9.060997,21.060997,23.060997,-2.939003,-30.939003,-53.939003,-26.939003,-23.939003,-40.939003,-21.939003,-5.939003,-6.939003,-9.939003,-11.939003,12.060997,10.060997,-15.939003,-1.939003,15.060997,30.0609
97,32.060997,24.060997,-10.939003,-7.939003,35.060997,16.060997,2.060997,13.060997,21.060997,29.060997,33.060997,39.060997,45.060997,35.060997,22.060997,9.060997,6.060997,8.060997,11.060997,8.060997,-0.939003,-2.939003,-3.939003,-5.939003,-6.939003,-6.939003,-6.939003,-7.939003,-8.939003,-9.939003,-12.939003,-20.939003,-25.939003,-29.939003,-27.939003,-30.939003,-35.939003,-47.939003,-57.939003,-65.939,-75.939,-81.939,-44.939003,-38.939003,-63.939003,-69.939,-74.939,-83.939,-87.939,-90.939,-97.939,-100.939,-99.939,-97.939,-95.939,-91.939,-89.939,-86.939,-84.939,-82.939,-80.939,-77.939,-75.939,-71.939,-69.939,-66.939,-64.939,-63.939003,-63.939003,-58.939003,-55.939003,-52.939003,-47.939003,-43.939003,-42.939003,-40.939003,-37.939003,-36.939003,-35.939003,-34.939003,-36.939003,-39.939003,-39.939003,-37.939003,-35.939003,-37.939003,-40.939003,-40.939003,-39.939003,-38.939003,-44.939003,-47.939003,-50.939003,-53.939003,-55.939003,-56.939003,-60.939003,-65.939,-67.939,-69.939,-70.939,-77.939,-84.939,-87.939,-89.939,-92.939,-92.939,-93.939,-95.939,-94.939,-91.939,-88.939,-81.939,-72.939,-55.939003,-49.939003,-53.939003,-46.939003,-40.939003,-41.939003,-49.939003,-57.939003,-40.939003,-30.939003,-28.939003,-37.939003,-44.939003,-46.939003,-46.939003,-45.939003,-45.939003,-46.939003,-47.939003,-50.939003,-51.939003,-48.939003,-48.939003,-48.939003,-45.939003,-43.939003,-43.939003,-40.939003,-37.939003,-32.939003,-26.939003,-21.939003,-21.939003,-19.939003,-15.939003,-13.939003,-16.939003,-30.939003,-42.939003,-49.939003,-48.939003,-47.939003,-46.939003,-46.939003,-45.939003,-43.939003,-41.939003,-40.939003,-39.939003,-33.939003,-25.939003,-22.939003,-21.939003,-21.939003,-20.939003,-17.939003,-18.939003,-20.939003,-23.939003,-23.939003,-16.939003,5.060997,1.060997,-11.939003,-5.939003,-4.939003,-9.939003,-14.939003,-16.939003,-12.939003,-12.939003,-14.939003,-13.939003,-14.939003,-17.939003,-23.939003,-29.939003,-34.939003,-30.939003,-24.939003,-27.939003,-27.939003,-24.939
003,-25.939003,-27.939003,-29.939003,-27.939003,-24.939003,-29.939003,-37.939003,-50.939003,-45.939003,-42.939003,-48.939003,-52.939003,-54.939003,-51.939003,-50.939003,-51.939003,-54.939003,-56.939003,-55.939003,-52.939003,-50.939003,-57.939003,-58.939003,-52.939003,-47.939003,-47.939003,-59.939003,-60.939003,-58.939003,-58.939003,-59.939003,-61.939003,-61.939003,-61.939003,-60.939003,-62.939003,-63.939003,-64.939,-66.939,-71.939,-70.939,-70.939,-69.939,-70.939,-71.939,-54.939003,-32.939003,-5.939003,-6.939003,-19.939003,-58.939003,-70.939,-74.939,-75.939,-76.939,-75.939,-78.939,-81.939,-80.939,-79.939,-79.939,-81.939,-87.939,-96.939,-90.939,-84.939,-83.939,-82.939,-83.939,-94.939,-98.939,-94.939,-96.939,-95.939,-85.939,-92.939,-101.939,-79.939,-55.939003,-28.939003,-47.939003,-61.939003,-51.939003,-47.939003,-44.939003,-44.939003,-44.939003,-43.939003,-38.939003,-36.939003,-45.939003,-43.939003,-38.939003,-39.939003,-35.939003,-25.939003,-28.939003,-30.939003,-24.939003,-27.939003,-33.939003,-42.939003,-52.939003,-64.939,-75.939,-81.939,-74.939,-64.939,-52.939003,-48.939003,-46.939003,-49.939003,-48.939003,-47.939003,-42.939003,-39.939003,-38.939003,-39.939003,-42.939003,-49.939003,-65.939,-77.939,-79.939,-79.939,-80.939,-83.939,-78.939,-66.939,-66.939,-67.939,-70.939,-70.939,-68.939,-74.939,-77.939,-75.939,-74.939,-73.939,-38.939003,-39.939003,-44.939003,-49.939003,-50.939003,-45.939003,-45.939003,-43.939003,-34.939003,-31.939003,-30.939003,-27.939003,-18.939003,-4.939003,-8.939003,-19.939003,-41.939003,-49.939003,-51.939003,-50.939003,-52.939003,-56.939003,-60.939003,-66.939,-74.939,-79.939,-84.939,-86.939,-89.939,-88.939,-90.939,-93.939,-98.939,-96.939,-92.939,-96.939,-99.939,-99.939,-100.939,-98.939,-96.939,-98.939,-97.939,-65.939,-40.939003,-23.939003,-60.939003,-91.939,-96.939,-95.939,-91.939,-93.939,-89.939,-79.939,-79.939,-77.939,-69.939,-62.939003,-56.939003,-57.939003,-53.939003,-46.939003,-41.939003,-36.939003,-28.939003,-23.939003,-19.939003,-14.939003
,-9.939003,-4.939003,3.060997,7.060997,3.060997,5.060997,8.060997,12.060997,12.060997,7.060997,5.060997,3.060997,-2.939003,-3.939003,-1.939003,5.060997,16.060997,29.060997,13.060997,-12.939003,-56.939003,-67.939,-66.939,-73.939,-79.939,-87.939,-94.939,-98.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-98.939,-97.939,-97.939,-95.939,-94.939,-92.939,-90.939,-85.939,-81.939,-77.939,-69.939,-60.939003,-53.939003,-49.939003,-44.939003,-36.939003,-29.939003,-25.939003,-19.939003,-14.939003,-10.939003,-3.939003,2.060997,3.060997,4.060997,6.060997,12.060997,17.060997,12.060997,7.060997,1.060997,-8.939003,-15.939003,-19.939003,-3.939003,6.060997,-3.939003,-24.939003,-46.939003,-45.939003,-49.939003,-59.939003,-66.939,-72.939,-77.939,-84.939,-92.939,-99.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-100.939,-97.939,-100.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-99.939,-98.939,-97.939,-97.939,-97.939,-96.939,-95.939,-93.939,-91.939,-88.939,-82.939,-74.939,-70.939,-69.939,-59.939003,-49.939003,-41.939003,-35.939003,-28.939003,-17.939003,-10.939003,-4.939003,0.06099701,4.060997,8.060997,11.060997,10.060997,1.060997,-7.939003,-15.939003,-13.939003,-18.939003,-33.939003,-29.939003,-24.939003,-19.939003,-12.939003,-5.939003,2.060997,7.060997,9.060997,22.060997,28.060997,16.060997,0.06099701,-11.939003,2.060997,17.060997,35.060997,66.061,86.061,78.061,70.061,63.060997,70.061,73.061,74.061,76.061,77.061,71.061,73.061,77.061,78.061,76.061,72.061,85.061,80.061,34.060997,38.060997,61.060997,65.061,66.061,62.060997,61.060997,49.060997,8.060997,7.060997,24.060997,37.060997,38.060997,28.060997,-0.939003,-23.939003,-27.939003,-17.939003,-4.939003,-16.939003,-22.939003,-24.939003,-27.939003,-29.939003,-24.939003,-30.939003,-38.939003,-36.939003,-37.939003,-42.939003,-36.939003,-30.939003,-23.939003,-18.939003,-15
.939003,-15.939003,-14.939003,-15.939003,-11.939003,-8.939003,-5.939003,3.060997,13.060997,9.060997,7.060997,6.060997,-1.939003,-8.939003,-16.939003,-20.939003,-24.939003,-29.939003,-24.939003,-10.939003,13.060997,36.060997,46.060997,41.060997,30.060997,26.060997,-10.939003,-80.939,-95.939,-100.939,-94.939,-96.939,-93.939,-23.939003,35.060997,83.061,84.061,66.061,18.060997,-40.939003,-96.939,-96.939,-96.939,-95.939,-92.939,-91.939,-89.939,-83.939,-73.939,-68.939,-63.939003,-57.939003,-52.939003,-47.939003,-39.939003,-33.939003,-27.939003,-20.939003,-13.939003,-7.939003,-3.939003,-0.939003,0.06099701,3.060997,9.060997,13.060997,12.060997,7.060997,2.060997,1.060997,14.060997,7.060997,-5.939003,-5.939003,-5.939003,-5.939003,-10.939003,-15.939003,-18.939003,-25.939003,-34.939003,-36.939003,-40.939003,-45.939003,-56.939003,-60.939003,-39.939003,6.060997,58.060997,70.061,63.060997,35.060997,-43.939003,-103.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-98.939,-98.939,-97.939,-97.939,-97.939,-94.939,-89.939,-85.939,-81.939,-81.939,-76.939,-67.939,-61.939003,-56.939003,-51.939003,-45.939003,-38.939003,-35.939003,-31.939003,-24.939003,-22.939003,-21.939003,-21.939003,-13.939003,-2.939003,1.060997,1.060997,-2.939003,-11.939003,-20.939003,-24.939003,-20.939003,-14.939003,-16.939003,-27.939003,-44.939003,-40.939003,-33.939003,-30.939003,-27.939003,-23.939003,-16.939003,-9.939003,-3.939003,3.060997,4.060997,-3.939003,-8.939003,-10.939003,7.060997,29.060997,55.060997,54.060997,49.060997,47.060997,49.060997,51.060997,55.060997,52.060997,44.060997,58.060997,53.060997,-1.939003,-37.939003,-59.939003,-27.939003,-17.939003,-29.939003,-12.939003,-3.939003,-21.939003,-32.939003,-33.939003,12.060997,9.060997,-43.939003,-16.939003,15.060997,33.060997,35.060997,24.060997,-19.939003,-19.939003,24.060997,22.060997,21.060997,30.060997,29.060997,26.060997,34.060997,43.060997,50.060997,24.060997,-4.939003,-30.939003,-16.939003,12.060997,24.060997,29.0
60997,25.060997,15.060997,3.060997,-12.939003,-6.939003,8.060997,-2.939003,-18.939003,-35.939003,-30.939003,-25.939003,-34.939003,-34.939003,-29.939003,-29.939003,-32.939003,-37.939003,-45.939003,-52.939003,-61.939003,-74.939,-83.939,-44.939003,-47.939003,-93.939,-98.939,-98.939,-97.939,-97.939,-97.939,-94.939,-92.939,-89.939,-85.939,-79.939,-71.939,-66.939,-61.939003,-56.939003,-53.939003,-48.939003,-43.939003,-37.939003,-32.939003,-26.939003,-21.939003,-18.939003,-15.939003,-15.939003,-13.939003,-13.939003,-13.939003,-13.939003,-12.939003,-17.939003,-20.939003,-23.939003,-25.939003,-28.939003,-33.939003,-37.939003,-39.939003,-21.939003,-6.939003,3.060997,-27.939003,-58.939003,-77.939,-82.939,-81.939,-91.939,-97.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-98.939,-98.939,-98.939,-98.939,-97.939,-97.939,-97.939,-93.939,-90.939,-87.939,-79.939,-71.939,-64.939,-58.939003,-51.939003,-44.939003,-42.939003,-47.939003,-45.939003,-42.939003,-36.939003,-36.939003,-38.939003,-50.939003,-58.939003,-61.939003,-52.939003,-48.939003,-57.939003,-57.939003,-53.939003,-48.939003,-47.939003,-49.939003,-49.939003,-46.939003,-36.939003,-24.939003,-11.939003,-7.939003,-4.939003,-1.939003,-0.939003,1.060997,4.060997,15.060997,28.060997,24.060997,24.060997,28.060997,29.060997,17.060997,-27.939003,-51.939003,-63.939003,-57.939003,-55.939003,-56.939003,-54.939003,-52.939003,-47.939003,-43.939003,-39.939003,-35.939003,-29.939003,-22.939003,-20.939003,-20.939003,-19.939003,-20.939003,-21.939003,-21.939003,-24.939003,-28.939003,-22.939003,-1.939003,56.060997,57.060997,36.060997,41.060997,42.060997,39.060997,29.060997,26.060997,36.060997,38.060997,36.060997,39.060997,38.060997,33.060997,22.060997,6.060997,-19.939003,-1.939003,30.060997,24.060997,22.060997,24.060997,23.060997,21.060997,16.060997,17.060997,17.060997,0.06099701,-27.939003,-63.939003,-66.939,-67.939,-69.939,-72.939,-73.939,-65.939,-63.939003,-64.939,-63.939003,-62.939003,-61.939003,-56.939003,-50.939003,-54.939
003,-54.939003,-47.939003,-40.939003,-39.939003,-47.939003,-48.939003,-46.939003,-43.939003,-41.939003,-40.939003,-43.939003,-43.939003,-38.939003,-37.939003,-36.939003,-39.939003,-42.939003,-47.939003,-41.939003,-36.939003,-38.939003,-40.939003,-44.939003,-52.939003,-58.939003,-63.939003,-57.939003,-48.939003,-41.939003,-38.939003,-39.939003,-41.939003,-45.939003,-48.939003,-51.939003,-53.939003,-52.939003,-50.939003,-49.939003,-54.939003,-66.939,-87.939,-88.939,-86.939,-80.939,-80.939,-82.939,-94.939,-97.939,-92.939,-96.939,-97.939,-90.939,-95.939,-101.939,-80.939,-58.939003,-34.939003,-60.939003,-70.939,-36.939003,-26.939003,-26.939003,-37.939003,-37.939003,-26.939003,-28.939003,-37.939003,-56.939003,-61.939003,-61.939003,-69.939,-64.939,-43.939003,-30.939003,-21.939003,-19.939003,-27.939003,-38.939003,-42.939003,-53.939003,-72.939,-81.939,-83.939,-73.939,-65.939,-58.939003,-39.939003,-35.939003,-45.939003,-52.939003,-55.939003,-52.939003,-47.939003,-42.939003,-37.939003,-35.939003,-38.939003,-53.939003,-66.939,-71.939,-75.939,-77.939,-75.939,-69.939,-63.939003,-65.939,-65.939,-56.939003,-60.939003,-71.939,-80.939,-85.939,-85.939,-81.939,-78.939,-4.939003,5.060997,18.060997,19.060997,21.060997,24.060997,25.060997,28.060997,36.060997,38.060997,38.060997,36.060997,40.060997,50.060997,42.060997,12.060997,-62.939003,-82.939,-80.939,-79.939,-80.939,-83.939,-85.939,-87.939,-89.939,-90.939,-89.939,-91.939,-94.939,-97.939,-93.939,-90.939,-93.939,-93.939,-92.939,-90.939,-89.939,-87.939,-85.939,-82.939,-79.939,-76.939,-72.939,-52.939003,-37.939003,-26.939003,-34.939003,-42.939003,-45.939003,-41.939003,-35.939003,-27.939003,-22.939003,-22.939003,-18.939003,-12.939003,-7.939003,-5.939003,-3.939003,-7.939003,-8.939003,-7.939003,-10.939003,-12.939003,-13.939003,-16.939003,-20.939003,-23.939003,-26.939003,-29.939003,-31.939003,-34.939003,-41.939003,-49.939003,-56.939003,-54.939003,-56.939003,-63.939003,-65.939,-66.939,-69.939,-16.939003,53.060997,57.060997,63.060997,71.061,46.0
60997,4.060997,-69.939,-91.939,-91.939,-92.939,-94.939,-98.939,-100.939,-101.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-103.939,-102.939,-100.939,-95.939,-89.939,-88.939,-86.939,-85.939,-79.939,-73.939,-69.939,-66.939,-63.939003,-57.939003,-49.939003,-41.939003,-36.939003,-30.939003,-28.939003,-23.939003,-17.939003,-12.939003,-11.939003,-13.939003,-9.939003,-5.939003,-4.939003,-6.939003,-9.939003,-9.939003,-10.939003,-13.939003,-16.939003,-19.939003,-26.939003,-31.939003,-35.939003,-35.939003,-40.939003,-50.939003,-54.939003,-55.939003,-57.939003,-62.939003,-61.939003,0.06099701,37.060997,50.060997,61.060997,61.060997,33.060997,-23.939003,-84.939,-84.939,-85.939,-88.939,-91.939,-93.939,-94.939,-97.939,-99.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-100.939,-97.939,-95.939,-92.939,-91.939,-88.939,-82.939,-77.939,-71.939,-66.939,-60.939003,-54.939003,-46.939003,-39.939003,-35.939003,-32.939003,-31.939003,-27.939003,-23.939003,-18.939003,-15.939003,-13.939003,-13.939003,-15.939003,-18.939003,-16.939003,-16.939003,-20.939003,-22.939003,-24.939003,-26.939003,-27.939003,-29.939003,-29.939003,-32.939003,-38.939003,-42.939003,-40.939003,-23.939003,-22.939003,-25.939003,-21.939003,-2.939003,30.060997,37.060997,40.060997,43.060997,44.060997,45.060997,50.060997,54.060997,57.060997,62.060997,63.060997,53.060997,24.060997,-5.939003,19.060997,34.060997,40.060997,71.061,91.061,81.061,72.061,65.061,67.061,66.061,66.061,65.061,62.060997,56.060997,53.060997,52.060997,47.060997,39.060997,31.060997,30.060997,25.060997,8.060997,5.060997,9.060997,8.060997,2.060997,-8.939003,-10.939003,-11.939003,-10.939003,6.060997,29.060997,32.060997,31.060997,29.060997,1.060997,-23.939003,-36.939003,-20.939003,5.060997,-3.939003,-5.939003,-2.939003,-1.939003,-0.939003,2.060997,-0.939003,-5.939003,-7.939003,-9.939003,-9.939003,-9.939003,-12.939003,-18.939003,-22.939003,-26.939003,-31.939003,-37.939003,-46.939003,-53.939003,-58.939003,-60.939003,-61.939003,-62.9390
03,-65.939,-66.939,-66.939,-69.939,-71.939,-74.939,-75.939,-77.939,-78.939,-65.939,-38.939003,6.060997,42.060997,45.060997,42.060997,37.060997,27.060997,-12.939003,-80.939,-88.939,-87.939,-79.939,-73.939,-66.939,-32.939003,-4.939003,17.060997,14.060997,5.060997,-9.939003,-21.939003,-31.939003,-26.939003,-23.939003,-20.939003,-17.939003,-14.939003,-13.939003,-8.939003,-2.939003,-3.939003,-4.939003,-4.939003,-4.939003,-6.939003,-8.939003,-11.939003,-14.939003,-14.939003,-14.939003,-17.939003,-21.939003,-25.939003,-28.939003,-33.939003,-39.939003,-43.939003,-44.939003,-40.939003,-40.939003,-46.939003,-58.939003,-57.939003,-52.939003,-64.939,-70.939,-70.939,-72.939,-74.939,-75.939,-77.939,-80.939,-81.939,-82.939,-84.939,-87.939,-85.939,-66.939,3.060997,84.061,75.061,58.060997,34.060997,-44.939003,-100.939,-92.939,-86.939,-81.939,-72.939,-69.939,-70.939,-63.939003,-56.939003,-49.939003,-49.939003,-51.939003,-45.939003,-40.939003,-35.939003,-32.939003,-29.939003,-24.939003,-20.939003,-16.939003,-16.939003,-13.939003,-7.939003,-6.939003,-5.939003,-6.939003,-6.939003,-5.939003,-7.939003,-9.939003,-9.939003,-13.939003,-17.939003,-18.939003,-20.939003,-22.939003,-29.939003,-36.939003,-42.939003,-36.939003,-31.939003,-30.939003,-38.939003,-47.939003,-21.939003,1.060997,22.060997,22.060997,22.060997,27.060997,31.060997,34.060997,37.060997,40.060997,43.060997,48.060997,42.060997,9.060997,-0.939003,0.06099701,18.060997,30.060997,38.060997,31.060997,24.060997,17.060997,16.060997,16.060997,12.060997,5.060997,-4.939003,-0.939003,-3.939003,-32.939003,-46.939003,-51.939003,-30.939003,-20.939003,-20.939003,-3.939003,1.060997,-25.939003,-28.939003,-18.939003,7.060997,4.060997,-28.939003,4.060997,34.060997,42.060997,34.060997,17.060997,-11.939003,-10.939003,19.060997,5.060997,9.060997,64.061,50.060997,15.060997,44.060997,78.061,114.061,61.060997,11.060997,-5.939003,-16.939003,-22.939003,-13.939003,-12.939003,-19.939003,-12.939003,-1.939003,14.060997,6.060997,-9.939003,-7.939003,-4.939003
,-0.939003,2.060997,0.06099701,-13.939003,-30.939003,-44.939003,-30.939003,-19.939003,-10.939003,-24.939003,-42.939003,-66.939,-75.939,-76.939,-54.939003,-44.939003,-45.939003,-45.939003,-43.939003,-39.939003,-38.939003,-38.939003,-33.939003,-32.939003,-34.939003,-32.939003,-28.939003,-25.939003,-25.939003,-27.939003,-28.939003,-29.939003,-28.939003,-29.939003,-32.939003,-38.939003,-41.939003,-44.939003,-47.939003,-50.939003,-53.939003,-58.939003,-64.939,-67.939,-70.939,-73.939,-74.939,-75.939,-76.939,-77.939,-78.939,-80.939,-67.939,-44.939003,17.060997,40.060997,24.060997,-25.939003,-69.939,-90.939,-93.939,-90.939,-88.939,-84.939,-79.939,-75.939,-73.939,-71.939,-63.939003,-55.939003,-54.939003,-53.939003,-50.939003,-48.939003,-46.939003,-42.939003,-40.939003,-39.939003,-41.939003,-39.939003,-32.939003,-30.939003,-29.939003,-29.939003,-34.939003,-40.939003,-44.939003,-45.939003,-44.939003,-40.939003,-37.939003,-37.939003,-43.939003,-50.939003,-31.939003,-23.939003,-24.939003,-31.939003,-31.939003,-18.939003,-12.939003,-8.939003,-1.939003,-3.939003,-16.939003,-22.939003,-20.939003,-5.939003,4.060997,13.060997,9.060997,4.060997,-1.939003,-1.939003,-0.939003,-2.939003,-4.939003,-6.939003,-14.939003,-16.939003,-10.939003,-16.939003,-23.939003,-34.939003,-38.939003,-39.939003,-36.939003,-34.939003,-33.939003,-29.939003,-28.939003,-32.939003,-35.939003,-38.939003,-36.939003,-29.939003,-18.939003,-21.939003,-23.939003,-22.939003,-20.939003,-16.939003,-16.939003,-20.939003,-26.939003,-25.939003,-18.939003,-2.939003,-4.939003,-12.939003,-6.939003,-3.939003,-5.939003,-8.939003,-8.939003,1.060997,4.060997,5.060997,8.060997,12.060997,16.060997,5.060997,-8.939003,-26.939003,-7.939003,23.060997,19.060997,16.060997,15.060997,20.060997,23.060997,19.060997,26.060997,32.060997,5.060997,-34.939003,-86.939,-89.939,-90.939,-91.939,-93.939,-93.939,-90.939,-90.939,-90.939,-90.939,-89.939,-89.939,-87.939,-85.939,-87.939,-81.939,-67.939,-74.939,-82.939,-84.939,-85.939,-84.939,-83.939,-82.93
9,-82.939,-83.939,-83.939,-81.939,-81.939,-81.939,-82.939,-83.939,-84.939,-82.939,-81.939,-81.939,-82.939,-79.939,-45.939003,-20.939003,-2.939003,-7.939003,-23.939003,-64.939,-72.939,-66.939,-66.939,-66.939,-67.939,-64.939,-63.939003,-62.939003,-61.939003,-61.939003,-61.939003,-69.939,-85.939,-85.939,-81.939,-73.939,-76.939,-83.939,-91.939,-94.939,-96.939,-97.939,-96.939,-90.939,-95.939,-100.939,-71.939,-50.939003,-37.939003,-55.939003,-63.939003,-41.939003,-39.939003,-43.939003,-49.939003,-42.939003,-21.939003,-25.939003,-35.939003,-49.939003,-52.939003,-53.939003,-72.939,-83.939,-82.939,-59.939003,-36.939003,-26.939003,-22.939003,-22.939003,-38.939003,-56.939003,-75.939,-83.939,-80.939,-56.939003,-48.939003,-50.939003,-49.939003,-45.939003,-37.939003,-39.939003,-43.939003,-54.939003,-51.939003,-44.939003,-40.939003,-40.939003,-41.939003,-53.939003,-61.939003,-62.939003,-62.939003,-62.939003,-64.939,-63.939003,-62.939003,-66.939,-68.939,-62.939003,-64.939,-72.939,-79.939,-84.939,-84.939,-84.939,-83.939,22.060997,38.060997,59.060997,60.060997,61.060997,63.060997,64.061,65.061,71.061,71.061,69.061,68.061,70.061,75.061,68.061,32.060997,-65.939,-94.939,-94.939,-90.939,-89.939,-91.939,-91.939,-91.939,-89.939,-84.939,-79.939,-79.939,-80.939,-83.939,-75.939,-68.939,-68.939,-67.939,-66.939,-60.939003,-56.939003,-54.939003,-51.939003,-48.939003,-45.939003,-42.939003,-39.939003,-30.939003,-25.939003,-23.939003,-18.939003,-16.939003,-18.939003,-15.939003,-12.939003,-6.939003,-6.939003,-11.939003,-9.939003,-7.939003,-5.939003,-8.939003,-10.939003,-15.939003,-18.939003,-21.939003,-26.939003,-31.939003,-34.939003,-39.939003,-44.939003,-49.939003,-55.939003,-60.939003,-66.939,-70.939,-76.939,-84.939,-93.939,-93.939,-95.939,-99.939,-101.939,-102.939,-103.939,-25.939003,73.061,81.061,86.061,88.061,62.060997,14.060997,-76.939,-100.939,-95.939,-93.939,-91.939,-89.939,-87.939,-85.939,-84.939,-82.939,-79.939,-77.939,-75.939,-72.939,-71.939,-69.939,-64.939,-58.939003,-51.939003,-51.9390
03,-50.939003,-47.939003,-41.939003,-36.939003,-35.939003,-32.939003,-29.939003,-25.939003,-20.939003,-15.939003,-12.939003,-10.939003,-10.939003,-7.939003,-4.939003,-3.939003,-5.939003,-11.939003,-12.939003,-13.939003,-15.939003,-20.939003,-26.939003,-28.939003,-32.939003,-37.939003,-42.939003,-47.939003,-55.939003,-62.939003,-69.939,-70.939,-75.939,-85.939,-92.939,-95.939,-95.939,-99.939,-93.939,2.060997,62.060997,88.061,94.061,86.061,54.060997,-20.939003,-103.939,-99.939,-97.939,-96.939,-93.939,-90.939,-88.939,-86.939,-84.939,-82.939,-80.939,-77.939,-74.939,-72.939,-69.939,-66.939,-62.939003,-60.939003,-57.939003,-53.939003,-51.939003,-48.939003,-44.939003,-40.939003,-35.939003,-30.939003,-27.939003,-25.939003,-20.939003,-17.939003,-16.939003,-15.939003,-16.939003,-15.939003,-14.939003,-10.939003,-9.939003,-10.939003,-14.939003,-16.939003,-17.939003,-17.939003,-19.939003,-24.939003,-25.939003,-25.939003,-27.939003,-28.939003,-28.939003,-28.939003,-29.939003,-31.939003,-31.939003,-27.939003,-15.939003,-7.939003,-2.939003,-7.939003,13.060997,59.060997,69.061,72.061,70.061,67.061,65.061,66.061,68.061,69.061,70.061,67.061,56.060997,25.060997,-7.939003,16.060997,28.060997,29.060997,48.060997,61.060997,50.060997,43.060997,38.060997,35.060997,32.060997,29.060997,27.060997,23.060997,21.060997,17.060997,14.060997,9.060997,2.060997,-5.939003,-7.939003,-8.939003,-11.939003,-13.939003,-14.939003,-13.939003,-15.939003,-20.939003,-18.939003,-16.939003,-13.939003,3.060997,25.060997,31.060997,32.060997,30.060997,4.060997,-22.939003,-45.939003,-31.939003,-5.939003,-11.939003,-15.939003,-18.939003,-18.939003,-18.939003,-20.939003,-22.939003,-24.939003,-27.939003,-29.939003,-27.939003,-29.939003,-34.939003,-42.939003,-49.939003,-53.939003,-59.939003,-66.939,-75.939,-83.939,-89.939,-93.939,-96.939,-99.939,-98.939,-97.939,-95.939,-91.939,-87.939,-84.939,-81.939,-79.939,-78.939,-66.939,-44.939003,2.060997,37.060997,39.060997,38.060997,35.060997,22.060997,-13.939003,-72.939,-66.939,-52
.939003,-43.939003,-37.939003,-32.939003,-23.939003,-19.939003,-18.939003,-22.939003,-23.939003,-19.939003,-13.939003,-6.939003,-2.939003,-1.939003,-3.939003,-2.939003,-1.939003,-2.939003,-0.939003,0.06099701,-5.939003,-10.939003,-15.939003,-18.939003,-22.939003,-27.939003,-32.939003,-38.939003,-40.939003,-42.939003,-47.939003,-52.939003,-57.939003,-61.939003,-67.939,-75.939,-82.939,-81.939,-74.939,-67.939,-71.939,-95.939,-90.939,-75.939,-91.939,-99.939,-100.939,-96.939,-93.939,-90.939,-88.939,-85.939,-83.939,-82.939,-80.939,-79.939,-74.939,-61.939003,-9.939003,51.060997,36.060997,23.060997,10.060997,-34.939003,-64.939,-53.939003,-46.939003,-42.939003,-34.939003,-32.939003,-33.939003,-27.939003,-19.939003,-12.939003,-13.939003,-17.939003,-16.939003,-13.939003,-9.939003,-7.939003,-6.939003,-6.939003,-6.939003,-5.939003,-7.939003,-7.939003,-5.939003,-9.939003,-11.939003,-17.939003,-21.939003,-24.939003,-27.939003,-31.939003,-33.939003,-38.939003,-42.939003,-44.939003,-48.939003,-54.939003,-62.939003,-69.939,-75.939,-52.939003,-32.939003,-28.939003,-45.939003,-65.939,-26.939003,11.060997,46.060997,38.060997,31.060997,34.060997,36.060997,36.060997,36.060997,34.060997,33.060997,34.060997,27.060997,-1.939003,-10.939003,-8.939003,1.060997,5.060997,2.060997,-4.939003,-11.939003,-14.939003,-15.939003,-16.939003,-20.939003,-26.939003,-33.939003,-31.939003,-33.939003,-46.939003,-48.939003,-44.939003,-29.939003,-20.939003,-15.939003,0.06099701,2.060997,-33.939003,-34.939003,-18.939003,6.060997,9.060997,-10.939003,9.060997,29.060997,37.060997,27.060997,8.060997,-12.939003,-13.939003,5.060997,-5.939003,1.060997,59.060997,40.060997,0.06099701,52.060997,101.061,147.061,82.061,25.060997,10.060997,-2.939003,-13.939003,-8.939003,-10.939003,-19.939003,-16.939003,-9.939003,5.060997,-0.939003,-13.939003,3.060997,13.060997,16.060997,17.060997,13.060997,1.060997,-24.939003,-51.939003,-28.939003,-11.939003,0.06099701,-15.939003,-36.939003,-59.939003,-69.939,-72.939,-57.939003,-40.939003,-19
.939003,-20.939003,-23.939003,-23.939003,-24.939003,-25.939003,-22.939003,-22.939003,-27.939003,-29.939003,-30.939003,-31.939003,-35.939003,-39.939003,-42.939003,-44.939003,-45.939003,-47.939003,-52.939003,-60.939003,-66.939,-71.939,-75.939,-78.939,-82.939,-88.939,-93.939,-94.939,-95.939,-96.939,-95.939,-93.939,-92.939,-90.939,-88.939,-88.939,-72.939,-46.939003,7.060997,25.060997,7.060997,-28.939003,-58.939003,-71.939,-70.939,-64.939,-60.939003,-55.939003,-49.939003,-46.939003,-44.939003,-44.939003,-37.939003,-28.939003,-29.939003,-29.939003,-28.939003,-29.939003,-30.939003,-29.939003,-29.939003,-28.939003,-33.939003,-35.939003,-33.939003,-34.939003,-37.939003,-40.939003,-45.939003,-49.939003,-45.939003,-44.939003,-45.939003,-40.939003,-37.939003,-41.939003,-49.939003,-55.939003,-20.939003,-4.939003,-6.939003,-19.939003,-24.939003,-7.939003,-1.939003,-0.939003,4.060997,0.06099701,-12.939003,-20.939003,-21.939003,-11.939003,-7.939003,-5.939003,-12.939003,-17.939003,-21.939003,-21.939003,-20.939003,-23.939003,-29.939003,-36.939003,-41.939003,-41.939003,-37.939003,-38.939003,-37.939003,-35.939003,-31.939003,-28.939003,-28.939003,-27.939003,-25.939003,-23.939003,-23.939003,-28.939003,-33.939003,-37.939003,-36.939003,-30.939003,-17.939003,-21.939003,-24.939003,-22.939003,-19.939003,-14.939003,-16.939003,-20.939003,-27.939003,-26.939003,-25.939003,-29.939003,-35.939003,-41.939003,-34.939003,-31.939003,-33.939003,-33.939003,-32.939003,-27.939003,-25.939003,-24.939003,-23.939003,-19.939003,-13.939003,-19.939003,-26.939003,-34.939003,-24.939003,-7.939003,-10.939003,-11.939003,-11.939003,-7.939003,-4.939003,-5.939003,0.06099701,6.060997,-10.939003,-37.939003,-73.939,-77.939,-78.939,-79.939,-80.939,-80.939,-82.939,-83.939,-84.939,-84.939,-84.939,-85.939,-86.939,-87.939,-86.939,-80.939,-69.939,-80.939,-90.939,-90.939,-91.939,-92.939,-94.939,-94.939,-92.939,-94.939,-95.939,-96.939,-96.939,-96.939,-97.939,-98.939,-98.939,-100.939,-101.939,-100.939,-101.939,-97.939,-40.939003,0.06
099701,28.060997,14.060997,-17.939003,-79.939,-93.939,-87.939,-86.939,-85.939,-84.939,-81.939,-79.939,-78.939,-78.939,-78.939,-77.939,-81.939,-89.939,-84.939,-78.939,-71.939,-74.939,-83.939,-90.939,-94.939,-97.939,-97.939,-95.939,-90.939,-95.939,-99.939,-61.939003,-39.939003,-32.939003,-47.939003,-55.939003,-42.939003,-42.939003,-48.939003,-51.939003,-42.939003,-23.939003,-24.939003,-30.939003,-44.939003,-49.939003,-51.939003,-72.939,-88.939,-97.939,-75.939,-51.939003,-37.939003,-29.939003,-26.939003,-40.939003,-55.939003,-72.939,-79.939,-76.939,-51.939003,-43.939003,-45.939003,-51.939003,-47.939003,-31.939003,-33.939003,-39.939003,-53.939003,-50.939003,-42.939003,-43.939003,-44.939003,-43.939003,-49.939003,-54.939003,-54.939003,-53.939003,-53.939003,-54.939003,-56.939003,-58.939003,-65.939,-68.939,-62.939003,-64.939,-71.939,-77.939,-82.939,-84.939,-84.939,-84.939,43.060997,59.060997,79.061,73.061,69.061,70.061,69.061,68.061,68.061,66.061,63.060997,68.061,71.061,70.061,69.061,40.060997,-49.939003,-84.939,-94.939,-83.939,-78.939,-80.939,-80.939,-78.939,-72.939,-63.939003,-54.939003,-50.939003,-47.939003,-46.939003,-36.939003,-27.939003,-23.939003,-19.939003,-14.939003,-5.939003,-1.939003,0.06099701,2.060997,4.060997,3.060997,2.060997,1.060997,-1.939003,-6.939003,-13.939003,-13.939003,-13.939003,-15.939003,-18.939003,-21.939003,-31.939003,-40.939003,-46.939003,-54.939003,-60.939003,-64.939,-71.939,-77.939,-82.939,-85.939,-88.939,-91.939,-93.939,-93.939,-93.939,-92.939,-94.939,-95.939,-97.939,-99.939,-99.939,-100.939,-100.939,-102.939,-102.939,-102.939,-102.939,-103.939,-103.939,-102.939,-32.939003,58.060997,76.061,83.061,81.061,59.060997,15.060997,-76.939,-94.939,-80.939,-76.939,-69.939,-62.939003,-55.939003,-50.939003,-46.939003,-39.939003,-32.939003,-24.939003,-18.939003,-12.939003,-6.939003,-2.939003,2.060997,6.060997,11.060997,7.060997,8.060997,13.060997,11.060997,8.060997,3.060997,2.060997,1.060997,-2.939003,-9.939003,-19.939003,-27.939003,-34.939003,-39.939003,-
46.939003,-53.939003,-57.939003,-63.939003,-72.939,-78.939,-84.939,-87.939,-90.939,-93.939,-94.939,-96.939,-97.939,-97.939,-98.939,-97.939,-97.939,-99.939,-100.939,-100.939,-100.939,-102.939,-102.939,-102.939,-103.939,-94.939,-3.939003,59.060997,95.061,93.061,81.061,59.060997,-15.939003,-102.939,-92.939,-86.939,-83.939,-73.939,-64.939,-59.939003,-53.939003,-46.939003,-41.939003,-34.939003,-24.939003,-17.939003,-9.939003,-2.939003,4.060997,9.060997,10.060997,13.060997,18.060997,17.060997,16.060997,12.060997,10.060997,9.060997,6.060997,-0.939003,-11.939003,-21.939003,-31.939003,-40.939003,-46.939003,-52.939003,-60.939003,-65.939,-67.939,-69.939,-71.939,-76.939,-74.939,-68.939,-62.939003,-56.939003,-53.939003,-43.939003,-32.939003,-22.939003,-12.939003,-1.939003,4.060997,14.060997,28.060997,44.060997,49.060997,25.060997,35.060997,52.060997,27.060997,28.060997,54.060997,67.061,73.061,62.060997,57.060997,55.060997,51.060997,47.060997,46.060997,46.060997,42.060997,26.060997,3.060997,-18.939003,-8.939003,-1.939003,1.060997,-0.939003,-3.939003,-12.939003,-15.939003,-17.939003,-24.939003,-30.939003,-36.939003,-37.939003,-37.939003,-33.939003,-33.939003,-34.939003,-35.939003,-36.939003,-38.939003,-30.939003,-23.939003,-24.939003,-18.939003,-10.939003,1.060997,12.060997,24.060997,36.060997,34.060997,0.06099701,-0.939003,14.060997,34.060997,39.060997,33.060997,7.060997,-21.939003,-54.939003,-52.939003,-36.939003,-40.939003,-52.939003,-72.939,-78.939,-84.939,-93.939,-96.939,-96.939,-96.939,-96.939,-96.939,-96.939,-97.939,-97.939,-98.939,-98.939,-99.939,-100.939,-100.939,-101.939,-102.939,-102.939,-101.939,-98.939,-91.939,-84.939,-78.939,-66.939,-56.939003,-47.939003,-38.939003,-31.939003,-27.939003,-26.939003,-27.939003,-0.939003,22.060997,30.060997,29.060997,25.060997,12.060997,-14.939003,-56.939003,-27.939003,2.060997,11.060997,11.060997,6.060997,1.060997,-8.939003,-23.939003,-25.939003,-21.939003,-11.939003,-15.939003,-23.939003,-24.939003,-30.939003,-42.939003,-47.939003,-52
.939003,-55.939003,-58.939003,-63.939003,-73.939,-83.939,-90.939,-93.939,-96.939,-96.939,-97.939,-97.939,-97.939,-97.939,-98.939,-97.939,-98.939,-99.939,-100.939,-100.939,-101.939,-98.939,-93.939,-77.939,-72.939,-96.939,-90.939,-74.939,-85.939,-91.939,-94.939,-82.939,-72.939,-65.939,-57.939003,-49.939003,-44.939003,-39.939003,-35.939003,-31.939003,-27.939003,-24.939003,-30.939003,-40.939003,-46.939003,-43.939003,-34.939003,-13.939003,4.060997,13.060997,16.060997,15.060997,12.060997,10.060997,9.060997,9.060997,11.060997,14.060997,11.060997,3.060997,-10.939003,-18.939003,-18.939003,-20.939003,-24.939003,-35.939003,-41.939003,-47.939003,-52.939003,-58.939003,-62.939003,-70.939,-76.939,-83.939,-90.939,-96.939,-96.939,-96.939,-96.939,-95.939,-96.939,-97.939,-98.939,-99.939,-97.939,-97.939,-100.939,-58.939003,-23.939003,-18.939003,-40.939003,-68.939,-29.939003,2.060997,26.060997,8.060997,-7.939003,-9.939003,-12.939003,-17.939003,-20.939003,-26.939003,-33.939003,-38.939003,-40.939003,-38.939003,-38.939003,-39.939003,-43.939003,-47.939003,-52.939003,-55.939003,-56.939003,-49.939003,-46.939003,-45.939003,-44.939003,-42.939003,-41.939003,-36.939003,-35.939003,-44.939003,-43.939003,-38.939003,-24.939003,-17.939003,-14.939003,1.060997,0.06099701,-44.939003,-49.939003,-33.939003,10.060997,24.060997,8.060997,0.06099701,0.06099701,17.060997,12.060997,-3.939003,-22.939003,-26.939003,-15.939003,-10.939003,-3.939003,14.060997,-0.939003,-17.939003,57.060997,112.061,148.061,89.061,36.060997,18.060997,24.060997,41.060997,40.060997,34.060997,25.060997,2.060997,-19.939003,-37.939003,-26.939003,-3.939003,28.060997,34.060997,15.060997,14.060997,14.060997,11.060997,-16.939003,-48.939003,-22.939003,-8.939003,-6.939003,-20.939003,-34.939003,-42.939003,-56.939003,-70.939,-53.939003,-35.939003,-15.939003,-24.939003,-36.939003,-49.939003,-56.939003,-59.939003,-59.939003,-62.939003,-67.939,-76.939,-85.939,-89.939,-94.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-99.939,-100.939,-100.939,-10
0.939,-101.939,-101.939,-102.939,-100.939,-95.939,-90.939,-83.939,-78.939,-73.939,-69.939,-63.939003,-58.939003,-57.939003,-52.939003,-46.939003,-49.939003,-50.939003,-49.939003,-35.939003,-24.939003,-21.939003,-13.939003,-3.939003,-7.939003,-10.939003,-11.939003,-11.939003,-13.939003,-20.939003,-21.939003,-18.939003,-23.939003,-27.939003,-32.939003,-41.939003,-50.939003,-57.939003,-62.939003,-65.939,-70.939,-79.939,-90.939,-93.939,-95.939,-97.939,-90.939,-78.939,-48.939003,-39.939003,-51.939003,-44.939003,-40.939003,-49.939003,-54.939003,-54.939003,-19.939003,-3.939003,-7.939003,-18.939003,-27.939003,-24.939003,-25.939003,-27.939003,-32.939003,-34.939003,-35.939003,-43.939003,-50.939003,-54.939003,-60.939003,-68.939,-75.939,-72.939,-62.939003,-60.939003,-59.939003,-58.939003,-58.939003,-59.939003,-55.939003,-52.939003,-51.939003,-35.939003,-24.939003,-28.939003,-29.939003,-31.939003,-33.939003,-32.939003,-30.939003,-34.939003,-37.939003,-36.939003,-37.939003,-36.939003,-37.939003,-31.939003,-20.939003,-20.939003,-21.939003,-19.939003,-17.939003,-14.939003,-19.939003,-25.939003,-31.939003,-25.939003,-21.939003,-25.939003,-36.939003,-48.939003,-43.939003,-42.939003,-45.939003,-46.939003,-47.939003,-49.939003,-50.939003,-51.939003,-56.939003,-58.939003,-58.939003,-53.939003,-47.939003,-43.939003,-51.939003,-62.939003,-63.939003,-61.939003,-56.939003,-60.939003,-62.939003,-60.939003,-60.939003,-60.939003,-47.939003,-35.939003,-24.939003,-29.939003,-33.939003,-34.939003,-34.939003,-34.939003,-39.939003,-43.939003,-45.939003,-45.939003,-46.939003,-48.939003,-51.939003,-56.939003,-53.939003,-53.939003,-53.939003,-57.939003,-63.939003,-65.939,-67.939,-70.939,-75.939,-75.939,-70.939,-75.939,-80.939,-81.939,-82.939,-83.939,-86.939,-87.939,-87.939,-93.939,-96.939,-95.939,-98.939,-97.939,-37.939003,5.060997,30.060997,7.060997,-30.939003,-85.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-98.939,-87.939,-76.939,-7
2.939,-74.939,-80.939,-91.939,-96.939,-95.939,-94.939,-93.939,-91.939,-96.939,-97.939,-51.939003,-24.939003,-18.939003,-34.939003,-45.939003,-38.939003,-37.939003,-40.939003,-42.939003,-39.939003,-31.939003,-23.939003,-23.939003,-43.939003,-51.939003,-55.939003,-68.939,-79.939,-86.939,-78.939,-66.939,-52.939003,-50.939003,-51.939003,-49.939003,-52.939003,-61.939003,-70.939,-72.939,-60.939003,-50.939003,-43.939003,-46.939003,-42.939003,-28.939003,-35.939003,-43.939003,-48.939003,-44.939003,-37.939003,-45.939003,-47.939003,-42.939003,-42.939003,-43.939003,-47.939003,-48.939003,-49.939003,-47.939003,-48.939003,-53.939003,-61.939003,-64.939,-56.939003,-59.939003,-69.939,-74.939,-79.939,-83.939,-81.939,-79.939,38.060997,56.060997,78.061,68.061,61.060997,58.060997,54.060997,50.060997,46.060997,40.060997,33.060997,28.060997,25.060997,23.060997,17.060997,2.060997,-31.939003,-40.939003,-37.939003,-31.939003,-26.939003,-23.939003,-24.939003,-24.939003,-20.939003,-18.939003,-17.939003,-16.939003,-16.939003,-16.939003,-17.939003,-18.939003,-18.939003,-21.939003,-25.939003,-24.939003,-25.939003,-29.939003,-33.939003,-36.939003,-40.939003,-45.939003,-50.939003,-45.939003,-33.939003,-16.939003,-35.939003,-55.939003,-62.939003,-66.939,-68.939,-73.939,-77.939,-79.939,-83.939,-86.939,-89.939,-92.939,-95.939,-97.939,-98.939,-99.939,-99.939,-98.939,-99.939,-98.939,-97.939,-97.939,-96.939,-95.939,-95.939,-94.939,-92.939,-87.939,-83.939,-82.939,-78.939,-73.939,-68.939,-62.939003,-56.939003,-25.939003,13.060997,16.060997,14.060997,10.060997,-0.939003,-13.939003,-33.939003,-30.939003,-18.939003,-16.939003,-13.939003,-9.939003,-10.939003,-13.939003,-16.939003,-15.939003,-13.939003,-16.939003,-18.939003,-19.939003,-21.939003,-23.939003,-25.939003,-28.939003,-30.939003,-34.939003,-38.939003,-42.939003,-48.939003,-52.939003,-54.939003,-55.939003,-56.939003,-58.939003,-61.939003,-66.939,-70.939,-74.939,-76.939,-80.939,-83.939,-85.939,-88.939,-92.939,-95.939,-98.939,-99.939,-100.939,-102.939,-10
2.939,-102.939,-101.939,-102.939,-101.939,-98.939,-96.939,-96.939,-91.939,-86.939,-82.939,-80.939,-76.939,-69.939,-64.939,-57.939003,-11.939003,19.060997,33.060997,25.060997,15.060997,9.060997,-10.939003,-34.939003,-23.939003,-18.939003,-18.939003,-15.939003,-14.939003,-14.939003,-14.939003,-13.939003,-18.939003,-20.939003,-20.939003,-22.939003,-24.939003,-25.939003,-26.939003,-29.939003,-31.939003,-33.939003,-34.939003,-37.939003,-41.939003,-46.939003,-42.939003,-37.939003,-30.939003,-29.939003,-33.939003,-30.939003,-27.939003,-27.939003,-24.939003,-22.939003,-17.939003,-11.939003,-3.939003,3.060997,8.060997,9.060997,14.060997,21.060997,25.060997,29.060997,35.060997,37.060997,39.060997,43.060997,43.060997,41.060997,39.060997,37.060997,38.060997,42.060997,38.060997,15.060997,14.060997,21.060997,8.060997,3.060997,7.060997,12.060997,14.060997,8.060997,4.060997,2.060997,0.06099701,-2.939003,-2.939003,-3.939003,-7.939003,-13.939003,-15.939003,-15.939003,-13.939003,-11.939003,-8.939003,-7.939003,-5.939003,-5.939003,-3.939003,-1.939003,-2.939003,-3.939003,-3.939003,0.06099701,4.060997,9.060997,13.060997,15.060997,15.060997,15.060997,16.060997,24.060997,25.060997,11.060997,17.060997,30.060997,41.060997,40.060997,29.060997,29.060997,23.060997,-0.939003,2.060997,15.060997,29.060997,32.060997,25.060997,15.060997,-3.939003,-47.939003,-66.939,-75.939,-75.939,-79.939,-88.939,-93.939,-95.939,-93.939,-89.939,-87.939,-82.939,-78.939,-73.939,-69.939,-64.939,-59.939003,-55.939003,-51.939003,-47.939003,-43.939003,-37.939003,-37.939003,-35.939003,-30.939003,-28.939003,-27.939003,-22.939003,-20.939003,-19.939003,-15.939003,-15.939003,-24.939003,-27.939003,-28.939003,-35.939003,-30.939003,-14.939003,8.060997,26.060997,32.060997,31.060997,26.060997,19.060997,-0.939003,-34.939003,-48.939003,-52.939003,-39.939003,-43.939003,-51.939003,-30.939003,-4.939003,26.060997,36.060997,35.060997,10.060997,-28.939003,-69.939,-69.939,-72.939,-78.939,-81.939,-83.939,-84.939,-86.939,-88.939,-93.939,-97.93
9,-101.939,-101.939,-100.939,-97.939,-93.939,-89.939,-86.939,-83.939,-80.939,-76.939,-72.939,-66.939,-63.939003,-61.939003,-56.939003,-51.939003,-46.939003,-38.939003,-34.939003,-37.939003,-35.939003,-31.939003,-32.939003,-30.939003,-25.939003,-24.939003,-20.939003,-10.939003,-8.939003,-8.939003,-8.939003,-9.939003,-10.939003,-12.939003,-16.939003,-20.939003,-21.939003,-21.939003,-17.939003,-15.939003,-16.939003,-23.939003,-29.939003,-31.939003,-33.939003,-35.939003,-42.939003,-47.939003,-49.939003,-51.939003,-51.939003,-50.939003,-52.939003,-55.939003,-62.939003,-66.939,-66.939,-67.939,-69.939,-75.939,-78.939,-80.939,-83.939,-85.939,-88.939,-90.939,-91.939,-92.939,-91.939,-90.939,-87.939,-84.939,-78.939,-75.939,-72.939,-67.939,-62.939003,-59.939003,-56.939003,-53.939003,-48.939003,-33.939003,-20.939003,-19.939003,-24.939003,-31.939003,-26.939003,-22.939003,-20.939003,-25.939003,-32.939003,-37.939003,-37.939003,-37.939003,-36.939003,-36.939003,-38.939003,-37.939003,-35.939003,-32.939003,-29.939003,-25.939003,-24.939003,-20.939003,-11.939003,-14.939003,-15.939003,-6.939003,-2.939003,1.060997,5.060997,7.060997,8.060997,17.060997,12.060997,-29.939003,-45.939003,-48.939003,-29.939003,-18.939003,-15.939003,-1.939003,2.060997,-22.939003,-34.939003,-33.939003,12.060997,24.060997,4.060997,-6.939003,-13.939003,-16.939003,4.060997,30.060997,-0.939003,-19.939003,-24.939003,0.06099701,18.060997,10.060997,10.060997,17.060997,49.060997,70.061,80.061,61.060997,43.060997,34.060997,34.060997,39.060997,37.060997,37.060997,39.060997,24.060997,5.060997,-19.939003,-9.939003,14.060997,26.060997,25.060997,12.060997,12.060997,12.060997,12.060997,2.060997,-9.939003,-6.939003,-9.939003,-17.939003,-27.939003,-36.939003,-39.939003,-50.939003,-63.939003,-64.939,-65.939,-64.939,-68.939,-74.939,-80.939,-83.939,-85.939,-85.939,-85.939,-86.939,-87.939,-88.939,-84.939,-83.939,-80.939,-80.939,-78.939,-75.939,-71.939,-67.939,-64.939,-61.939003,-60.939003,-58.939003,-55.939003,-52.939003,-50.939003,-47
.939003,-45.939003,-42.939003,-37.939003,-36.939003,-35.939003,-35.939003,-32.939003,-31.939003,-34.939003,-35.939003,-36.939003,-37.939003,-35.939003,-30.939003,-35.939003,-42.939003,-46.939003,-47.939003,-48.939003,-54.939003,-58.939003,-61.939003,-61.939003,-62.939003,-66.939,-66.939,-65.939,-68.939,-70.939,-72.939,-75.939,-78.939,-81.939,-82.939,-83.939,-83.939,-84.939,-86.939,-83.939,-80.939,-78.939,-70.939,-58.939003,-43.939003,-37.939003,-41.939003,-38.939003,-37.939003,-40.939003,-40.939003,-38.939003,-37.939003,-36.939003,-36.939003,-35.939003,-36.939003,-42.939003,-41.939003,-38.939003,-40.939003,-39.939003,-34.939003,-39.939003,-43.939003,-42.939003,-40.939003,-39.939003,-38.939003,-33.939003,-26.939003,-21.939003,-17.939003,-15.939003,-9.939003,-3.939003,-6.939003,-7.939003,-4.939003,1.060997,-3.939003,-37.939003,-46.939003,-44.939003,-46.939003,-46.939003,-44.939003,-46.939003,-46.939003,-42.939003,-40.939003,-37.939003,-32.939003,-26.939003,-22.939003,-23.939003,-23.939003,-18.939003,-16.939003,-16.939003,-18.939003,-26.939003,-38.939003,-20.939003,1.060997,27.060997,18.060997,-1.939003,6.060997,7.060997,-0.939003,-2.939003,-1.939003,6.060997,6.060997,5.060997,5.060997,3.060997,-0.939003,-14.939003,-25.939003,-25.939003,-15.939003,-3.939003,-7.939003,-9.939003,-8.939003,-10.939003,-12.939003,-15.939003,-12.939003,-8.939003,-21.939003,-36.939003,-56.939003,-57.939003,-58.939003,-58.939003,-57.939003,-55.939003,-56.939003,-58.939003,-60.939003,-58.939003,-57.939003,-56.939003,-56.939003,-56.939003,-56.939003,-53.939003,-49.939003,-50.939003,-52.939003,-53.939003,-55.939003,-56.939003,-55.939003,-54.939003,-52.939003,-50.939003,-49.939003,-51.939003,-53.939003,-56.939003,-57.939003,-57.939003,-58.939003,-59.939003,-59.939003,-56.939003,-57.939003,-58.939003,-47.939003,-34.939003,-21.939003,-29.939003,-41.939003,-57.939003,-63.939003,-64.939,-65.939,-65.939,-66.939,-68.939,-69.939,-70.939,-69.939,-68.939,-71.939,-76.939,-83.939,-84.939,-82.939,-75.939,-77.
939,-83.939,-94.939,-98.939,-94.939,-95.939,-95.939,-92.939,-94.939,-95.939,-62.939003,-42.939003,-33.939003,-40.939003,-43.939003,-36.939003,-34.939003,-36.939003,-38.939003,-39.939003,-38.939003,-32.939003,-32.939003,-46.939003,-55.939003,-62.939003,-74.939,-81.939,-83.939,-77.939,-71.939,-65.939,-62.939003,-61.939003,-53.939003,-46.939003,-40.939003,-49.939003,-54.939003,-50.939003,-44.939003,-37.939003,-38.939003,-39.939003,-39.939003,-46.939003,-51.939003,-50.939003,-44.939003,-37.939003,-42.939003,-45.939003,-46.939003,-45.939003,-45.939003,-48.939003,-49.939003,-50.939003,-45.939003,-46.939003,-50.939003,-59.939003,-65.939,-60.939003,-62.939003,-67.939,-73.939,-78.939,-81.939,-79.939,-78.939,21.060997,36.060997,54.060997,44.060997,35.060997,29.060997,24.060997,20.060997,14.060997,7.060997,1.060997,-8.939003,-14.939003,-15.939003,-21.939003,-24.939003,-19.939003,-10.939003,-2.939003,-1.939003,1.060997,4.060997,1.060997,-0.939003,-0.939003,-4.939003,-9.939003,-11.939003,-13.939003,-14.939003,-23.939003,-30.939003,-33.939003,-40.939003,-49.939003,-53.939003,-58.939003,-64.939,-71.939,-77.939,-83.939,-90.939,-94.939,-85.939,-59.939003,-20.939003,-52.939003,-85.939,-99.939,-103.939,-103.939,-102.939,-101.939,-101.939,-101.939,-100.939,-99.939,-98.939,-97.939,-96.939,-95.939,-93.939,-89.939,-84.939,-86.939,-83.939,-80.939,-79.939,-75.939,-72.939,-70.939,-67.939,-66.939,-59.939003,-52.939003,-48.939003,-44.939003,-37.939003,-31.939003,-24.939003,-15.939003,-14.939003,-16.939003,-22.939003,-27.939003,-31.939003,-32.939003,-26.939003,-8.939003,2.060997,9.060997,8.060997,8.060997,7.060997,-0.939003,-8.939003,-16.939003,-19.939003,-22.939003,-31.939003,-38.939003,-44.939003,-49.939003,-55.939003,-62.939003,-69.939,-74.939,-78.939,-84.939,-93.939,-99.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-101.939,-99.939,-97.939,-97.939,-96.939,-94.939,-93.939,-91.939,-89.939,-88.939,-86.939,-83.939,-81.939,-79.939,-73.939,
-71.939,-70.939,-61.939003,-54.939003,-49.939003,-45.939003,-41.939003,-32.939003,-26.939003,-23.939003,-14.939003,-10.939003,-10.939003,-18.939003,-24.939003,-19.939003,-9.939003,2.060997,9.060997,11.060997,7.060997,5.060997,1.060997,-1.939003,-3.939003,-5.939003,-13.939003,-19.939003,-24.939003,-31.939003,-37.939003,-40.939003,-44.939003,-48.939003,-50.939003,-51.939003,-53.939003,-56.939003,-59.939003,-62.939003,-55.939003,-44.939003,-32.939003,-25.939003,-22.939003,-12.939003,-3.939003,2.060997,9.060997,17.060997,28.060997,38.060997,48.060997,57.060997,64.061,66.061,71.061,74.061,73.061,74.061,79.061,74.061,70.061,69.061,61.060997,50.060997,44.060997,36.060997,29.060997,24.060997,16.060997,0.06099701,-6.939003,-9.939003,-10.939003,-15.939003,-23.939003,-25.939003,-25.939003,-25.939003,-25.939003,-25.939003,-26.939003,-26.939003,-24.939003,-25.939003,-25.939003,-23.939003,-18.939003,-12.939003,-10.939003,-8.939003,-4.939003,6.060997,15.060997,19.060997,23.060997,26.060997,29.060997,33.060997,36.060997,43.060997,49.060997,54.060997,58.060997,62.060997,62.060997,62.060997,62.060997,68.061,66.061,45.060997,37.060997,36.060997,40.060997,29.060997,4.060997,-3.939003,-7.939003,-7.939003,1.060997,13.060997,26.060997,28.060997,18.060997,16.060997,2.060997,-43.939003,-76.939,-101.939,-96.939,-92.939,-91.939,-90.939,-83.939,-68.939,-60.939003,-56.939003,-50.939003,-44.939003,-37.939003,-32.939003,-26.939003,-19.939003,-14.939003,-9.939003,-5.939003,-0.939003,6.060997,5.060997,6.060997,12.060997,13.060997,12.060997,11.060997,9.060997,5.060997,2.060997,-5.939003,-25.939003,-36.939003,-44.939003,-56.939003,-45.939003,-12.939003,8.060997,25.060997,33.060997,31.060997,24.060997,17.060997,0.06099701,-26.939003,-71.939,-99.939,-86.939,-91.939,-100.939,-65.939,-16.939003,48.060997,73.061,73.061,23.060997,-38.939003,-98.939,-94.939,-92.939,-91.939,-89.939,-88.939,-88.939,-87.939,-86.939,-84.939,-83.939,-83.939,-80.939,-76.939,-71.939,-65.939,-59.939003,-55.939003,-50.939003,-45.939
003,-41.939003,-36.939003,-28.939003,-24.939003,-20.939003,-15.939003,-10.939003,-8.939003,-7.939003,-6.939003,0.06099701,-0.939003,-3.939003,-3.939003,1.060997,9.060997,2.060997,-1.939003,6.060997,3.060997,-1.939003,-5.939003,-10.939003,-14.939003,-21.939003,-27.939003,-32.939003,-10.939003,19.060997,26.060997,25.060997,12.060997,-32.939003,-69.939,-77.939,-81.939,-84.939,-91.939,-96.939,-97.939,-99.939,-99.939,-98.939,-97.939,-97.939,-96.939,-95.939,-93.939,-92.939,-91.939,-90.939,-88.939,-86.939,-86.939,-85.939,-84.939,-81.939,-78.939,-73.939,-68.939,-62.939003,-58.939003,-53.939003,-46.939003,-43.939003,-39.939003,-32.939003,-24.939003,-19.939003,-19.939003,-16.939003,-8.939003,-15.939003,-20.939003,-21.939003,-15.939003,-9.939003,-21.939003,-32.939003,-38.939003,-34.939003,-32.939003,-37.939003,-34.939003,-30.939003,-26.939003,-22.939003,-18.939003,-12.939003,-8.939003,-15.939003,-14.939003,-8.939003,1.060997,14.060997,31.060997,26.060997,23.060997,28.060997,33.060997,36.060997,40.060997,41.060997,36.060997,48.060997,39.060997,-22.939003,-49.939003,-56.939003,-33.939003,-19.939003,-15.939003,-5.939003,1.060997,-1.939003,-15.939003,-26.939003,15.060997,24.060997,4.060997,-5.939003,-19.939003,-41.939003,-6.939003,44.060997,18.060997,-2.939003,-19.939003,13.060997,33.060997,9.060997,18.060997,40.060997,38.060997,35.060997,30.060997,40.060997,48.060997,44.060997,39.060997,33.060997,31.060997,34.060997,42.060997,37.060997,27.060997,5.060997,12.060997,27.060997,22.060997,16.060997,11.060997,10.060997,10.060997,8.060997,11.060997,15.060997,-1.939003,-14.939003,-24.939003,-31.939003,-36.939003,-40.939003,-49.939003,-61.939003,-72.939,-82.939,-90.939,-90.939,-89.939,-87.939,-86.939,-86.939,-84.939,-81.939,-78.939,-74.939,-69.939,-61.939003,-55.939003,-50.939003,-50.939003,-48.939003,-45.939003,-39.939003,-34.939003,-32.939003,-28.939003,-26.939003,-25.939003,-23.939003,-20.939003,-17.939003,-16.939003,-20.939003,-19.939003,-17.939003,-21.939003,-23.939003,-27.939003,-28
.939003,-30.939003,-34.939003,-31.939003,-26.939003,-10.939003,-4.939003,-4.939003,-36.939003,-66.939,-76.939,-83.939,-88.939,-93.939,-96.939,-96.939,-97.939,-97.939,-94.939,-91.939,-89.939,-91.939,-90.939,-88.939,-84.939,-82.939,-82.939,-80.939,-77.939,-74.939,-70.939,-65.939,-59.939003,-53.939003,-49.939003,-44.939003,-39.939003,-39.939003,-37.939003,-34.939003,-35.939003,-35.939003,-36.939003,-33.939003,-29.939003,-44.939003,-50.939003,-47.939003,-40.939003,-37.939003,-44.939003,-40.939003,-33.939003,-31.939003,-29.939003,-28.939003,-30.939003,-30.939003,-20.939003,-12.939003,-5.939003,-1.939003,2.060997,3.060997,8.060997,12.060997,13.060997,20.060997,28.060997,18.060997,14.060997,17.060997,15.060997,0.06099701,-45.939003,-54.939003,-49.939003,-50.939003,-50.939003,-48.939003,-48.939003,-46.939003,-42.939003,-41.939003,-39.939003,-29.939003,-24.939003,-23.939003,-24.939003,-24.939003,-17.939003,-16.939003,-17.939003,-17.939003,-25.939003,-40.939003,-19.939003,9.060997,48.060997,41.060997,19.060997,30.060997,30.060997,20.060997,18.060997,22.060997,36.060997,38.060997,36.060997,40.060997,40.060997,37.060997,10.060997,-10.939003,-14.939003,8.060997,38.060997,31.060997,28.060997,26.060997,28.060997,26.060997,20.060997,27.060997,35.060997,-0.939003,-40.939003,-87.939,-87.939,-84.939,-85.939,-83.939,-81.939,-79.939,-78.939,-79.939,-77.939,-75.939,-72.939,-70.939,-67.939,-68.939,-63.939003,-51.939003,-54.939003,-57.939003,-57.939003,-57.939003,-57.939003,-53.939003,-51.939003,-51.939003,-45.939003,-41.939003,-42.939003,-45.939003,-49.939003,-48.939003,-48.939003,-49.939003,-46.939003,-44.939003,-41.939003,-40.939003,-40.939003,-49.939003,-49.939003,-39.939003,-44.939003,-47.939003,-44.939003,-44.939003,-43.939003,-44.939003,-45.939003,-45.939003,-47.939003,-50.939003,-51.939003,-49.939003,-47.939003,-51.939003,-59.939003,-73.939,-83.939,-87.939,-79.939,-79.939,-84.939,-95.939,-99.939,-94.939,-94.939,-94.939,-92.939,-93.939,-93.939,-73.939,-58.939003,-47.939003,-46.93900
3,-42.939003,-33.939003,-31.939003,-34.939003,-37.939003,-39.939003,-41.939003,-41.939003,-44.939003,-50.939003,-59.939003,-67.939,-77.939,-82.939,-79.939,-75.939,-72.939,-73.939,-70.939,-66.939,-56.939003,-42.939003,-24.939003,-31.939003,-38.939003,-39.939003,-36.939003,-33.939003,-33.939003,-37.939003,-47.939003,-53.939003,-56.939003,-52.939003,-45.939003,-39.939003,-41.939003,-44.939003,-49.939003,-50.939003,-49.939003,-49.939003,-49.939003,-50.939003,-46.939003,-47.939003,-49.939003,-58.939003,-65.939,-64.939,-63.939003,-64.939,-72.939,-77.939,-79.939,-79.939,-78.939,-14.939003,-16.939003,-20.939003,-20.939003,-25.939003,-34.939003,-33.939003,-33.939003,-36.939003,-34.939003,-31.939003,-29.939003,-25.939003,-21.939003,-17.939003,-15.939003,-17.939003,-23.939003,-30.939003,-32.939003,-37.939003,-46.939003,-55.939003,-61.939003,-61.939003,-67.939,-75.939,-76.939,-78.939,-79.939,-79.939,-81.939,-85.939,-89.939,-91.939,-90.939,-90.939,-89.939,-93.939,-95.939,-97.939,-98.939,-99.939,-96.939,-70.939,-21.939003,-47.939003,-77.939,-97.939,-103.939,-102.939,-98.939,-96.939,-96.939,-92.939,-88.939,-84.939,-81.939,-78.939,-73.939,-67.939,-62.939003,-50.939003,-42.939003,-41.939003,-33.939003,-23.939003,-21.939003,-17.939003,-10.939003,-5.939003,-2.939003,-4.939003,-2.939003,1.060997,6.060997,7.060997,4.060997,2.060997,0.06099701,-0.939003,-1.939003,-0.939003,2.060997,7.060997,13.060997,15.060997,4.060997,-37.939003,-52.939003,-55.939003,-56.939003,-60.939003,-66.939,-71.939,-77.939,-82.939,-84.939,-84.939,-86.939,-88.939,-89.939,-91.939,-92.939,-94.939,-95.939,-96.939,-97.939,-99.939,-101.939,-102.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-103.939,-101.939,-98.939,-97.939,-95.939,-92.939,-84.939,-79.939,-76.939,-71.939,-66.939,-59.939003,-52.939003,-44.939003,-38.939003,-31.939003,-25.939003,-16.939003,-9.939003,-1.939003,0.06099701,1.060997,3.060997,6.060997,11.060997,9.060997,7.060997,1.060997,-1.939003,-4.939003,-6.939003,-2.939003,7.060997,9.060997,11.0
60997,11.060997,-15.939003,-52.939003,-57.939003,-63.939003,-69.939,-69.939,-67.939,-66.939,-62.939003,-56.939003,-45.939003,-40.939003,-39.939003,-34.939003,-26.939003,-15.939003,-6.939003,4.060997,15.060997,26.060997,37.060997,43.060997,48.060997,48.060997,55.060997,63.060997,65.061,69.061,72.061,69.061,66.061,61.060997,61.060997,60.060997,59.060997,52.060997,40.060997,33.060997,27.060997,19.060997,12.060997,6.060997,-1.939003,-7.939003,-12.939003,-14.939003,-17.939003,-18.939003,-20.939003,-20.939003,-19.939003,-16.939003,-11.939003,-10.939003,-12.939003,-14.939003,-15.939003,-14.939003,-12.939003,-11.939003,-11.939003,-7.939003,-3.939003,-0.939003,6.060997,12.060997,12.060997,15.060997,21.060997,28.060997,34.060997,37.060997,16.060997,-11.939003,7.060997,20.060997,27.060997,56.060997,76.061,71.061,65.061,61.060997,68.061,69.061,67.061,70.061,73.061,75.061,74.061,73.061,77.061,75.061,67.061,69.061,68.061,60.060997,11.060997,-51.939003,-63.939003,-72.939,-78.939,-80.939,-68.939,-25.939003,-7.939003,1.060997,27.060997,33.060997,16.060997,-1.939003,-23.939003,-55.939003,-78.939,-95.939,-84.939,-77.939,-76.939,-55.939003,-31.939003,-6.939003,2.060997,4.060997,1.060997,2.060997,4.060997,1.060997,-1.939003,-3.939003,-4.939003,-6.939003,-10.939003,-15.939003,-20.939003,-25.939003,-28.939003,-30.939003,-35.939003,-42.939003,-47.939003,-54.939003,-63.939003,-65.939,-69.939,-80.939,-86.939,-89.939,-92.939,-73.939,-32.939003,-13.939003,3.060997,24.060997,22.060997,9.060997,-10.939003,-31.939003,-51.939003,-80.939,-100.939,-98.939,-100.939,-102.939,-93.939,-62.939003,-10.939003,34.060997,52.060997,8.060997,-39.939003,-81.939,-66.939,-56.939003,-50.939003,-43.939003,-38.939003,-40.939003,-35.939003,-28.939003,-19.939003,-14.939003,-15.939003,-11.939003,-6.939003,-1.939003,2.060997,5.060997,6.060997,7.060997,9.060997,9.060997,7.060997,3.060997,4.060997,5.060997,3.060997,-1.939003,-6.939003,-9.939003,-13.939003,-20.939003,-22.939003,-22.939003,-38.939003,-47.939003,-47.939003,-
55.939003,-64.939,-70.939,-71.939,-71.939,-77.939,-81.939,-83.939,-84.939,-82.939,-73.939,-2.939003,84.061,81.061,70.061,51.060997,-31.939003,-95.939,-97.939,-98.939,-99.939,-99.939,-97.939,-96.939,-91.939,-87.939,-83.939,-79.939,-76.939,-72.939,-66.939,-59.939003,-56.939003,-53.939003,-47.939003,-40.939003,-31.939003,-31.939003,-27.939003,-20.939003,-15.939003,-11.939003,-8.939003,-6.939003,-3.939003,-1.939003,-0.939003,1.060997,-2.939003,-4.939003,-4.939003,-3.939003,-3.939003,-10.939003,-15.939003,-19.939003,-23.939003,-25.939003,-22.939003,-26.939003,-31.939003,-16.939003,-0.939003,16.060997,20.060997,21.060997,24.060997,28.060997,33.060997,35.060997,38.060997,41.060997,49.060997,46.060997,16.060997,4.060997,2.060997,29.060997,47.060997,56.060997,44.060997,33.060997,30.060997,28.060997,28.060997,25.060997,16.060997,0.06099701,10.060997,6.060997,-40.939003,-55.939003,-55.939003,-33.939003,-19.939003,-12.939003,-10.939003,-5.939003,10.060997,5.060997,-6.939003,16.060997,26.060997,22.060997,10.060997,-7.939003,-40.939003,-28.939003,1.060997,24.060997,28.060997,13.060997,23.060997,25.060997,2.060997,7.060997,21.060997,21.060997,26.060997,38.060997,47.060997,49.060997,42.060997,35.060997,30.060997,30.060997,30.060997,30.060997,30.060997,32.060997,33.060997,31.060997,27.060997,22.060997,18.060997,14.060997,12.060997,6.060997,-3.939003,-3.939003,-1.939003,-22.939003,-29.939003,-23.939003,-27.939003,-34.939003,-44.939003,-58.939003,-70.939,-68.939,-61.939003,-48.939003,-45.939003,-44.939003,-40.939003,-35.939003,-28.939003,-24.939003,-20.939003,-15.939003,-14.939003,-14.939003,-13.939003,-10.939003,-5.939003,-9.939003,-13.939003,-19.939003,-18.939003,-18.939003,-23.939003,-26.939003,-28.939003,-34.939003,-39.939003,-45.939003,-48.939003,-52.939003,-61.939003,-67.939,-69.939,-73.939,-78.939,-85.939,-86.939,-86.939,-87.939,-60.939003,-19.939003,30.060997,43.060997,17.060997,-41.939003,-87.939,-95.939,-94.939,-89.939,-87.939,-82.939,-75.939,-76.939,-75.939,-65.939,-53.9390
03,-43.939003,-49.939003,-47.939003,-36.939003,-33.939003,-30.939003,-29.939003,-26.939003,-22.939003,-21.939003,-20.939003,-20.939003,-20.939003,-20.939003,-20.939003,-26.939003,-33.939003,-38.939003,-40.939003,-37.939003,-36.939003,-37.939003,-44.939003,-44.939003,-40.939003,-20.939003,-9.939003,-8.939003,-15.939003,-19.939003,-5.939003,1.060997,5.060997,9.060997,1.060997,-16.939003,-17.939003,-9.939003,11.060997,16.060997,14.060997,9.060997,5.060997,2.060997,-2.939003,-5.939003,-9.939003,-10.939003,-13.939003,-26.939003,-34.939003,-35.939003,-36.939003,-38.939003,-44.939003,-39.939003,-30.939003,-30.939003,-28.939003,-25.939003,-23.939003,-24.939003,-30.939003,-36.939003,-40.939003,-32.939003,-26.939003,-23.939003,-22.939003,-20.939003,-18.939003,-17.939003,-16.939003,-19.939003,-24.939003,-32.939003,-28.939003,-23.939003,-18.939003,-24.939003,-32.939003,-22.939003,-20.939003,-28.939003,-27.939003,-22.939003,-13.939003,-11.939003,-11.939003,-13.939003,-9.939003,1.060997,-12.939003,-24.939003,-26.939003,-9.939003,10.060997,5.060997,4.060997,8.060997,13.060997,15.060997,11.060997,19.060997,28.060997,-9.939003,-51.939003,-96.939,-99.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-98.939,-97.939,-96.939,-96.939,-95.939,-95.939,-95.939,-84.939,-64.939,-79.939,-92.939,-92.939,-92.939,-92.939,-91.939,-90.939,-90.939,-88.939,-87.939,-89.939,-90.939,-90.939,-90.939,-90.939,-90.939,-90.939,-89.939,-89.939,-88.939,-81.939,-28.939003,9.060997,31.060997,-0.939003,-39.939003,-78.939,-82.939,-74.939,-75.939,-75.939,-73.939,-72.939,-71.939,-70.939,-70.939,-69.939,-67.939,-70.939,-78.939,-84.939,-86.939,-80.939,-79.939,-79.939,-94.939,-97.939,-92.939,-89.939,-89.939,-93.939,-95.939,-92.939,-76.939,-63.939003,-53.939003,-49.939003,-43.939003,-29.939003,-29.939003,-34.939003,-40.939003,-41.939003,-37.939003,-48.939003,-55.939003,-53.939003,-58.939003,-66.939,-72.939,-74.939,-73.939,-67.939,-65.939,-72.939,-71.939,-67.939,-61.939003,-45.939003,-19.939003,-25.939003,-31.939003,-
31.939003,-31.939003,-31.939003,-34.939003,-39.939003,-43.939003,-45.939003,-48.939003,-52.939003,-50.939003,-48.939003,-48.939003,-49.939003,-50.939003,-51.939003,-50.939003,-47.939003,-48.939003,-50.939003,-51.939003,-53.939003,-54.939003,-60.939003,-63.939003,-63.939003,-61.939003,-59.939003,-70.939,-76.939,-78.939,-80.939,-83.939,-14.939003,-12.939003,-10.939003,-8.939003,-9.939003,-11.939003,-6.939003,-2.939003,-1.939003,1.060997,5.060997,9.060997,12.060997,15.060997,25.060997,22.060997,-11.939003,-41.939003,-67.939,-68.939,-72.939,-79.939,-86.939,-89.939,-89.939,-93.939,-98.939,-98.939,-99.939,-100.939,-98.939,-97.939,-99.939,-99.939,-98.939,-96.939,-92.939,-86.939,-86.939,-84.939,-81.939,-79.939,-78.939,-76.939,-58.939003,-25.939003,-35.939003,-50.939003,-65.939,-66.939,-60.939003,-54.939003,-50.939003,-48.939003,-43.939003,-38.939003,-35.939003,-33.939003,-31.939003,-28.939003,-26.939003,-24.939003,-20.939003,-17.939003,-17.939003,-15.939003,-13.939003,-14.939003,-16.939003,-18.939003,-18.939003,-19.939003,-24.939003,-27.939003,-29.939003,-27.939003,-29.939003,-34.939003,-37.939003,-40.939003,-42.939003,-23.939003,7.060997,40.060997,56.060997,52.060997,57.060997,37.060997,-45.939003,-77.939,-86.939,-86.939,-88.939,-93.939,-96.939,-100.939,-102.939,-103.939,-103.939,-103.939,-102.939,-100.939,-99.939,-98.939,-97.939,-96.939,-93.939,-91.939,-87.939,-82.939,-78.939,-75.939,-72.939,-68.939,-64.939,-61.939003,-58.939003,-54.939003,-51.939003,-47.939003,-44.939003,-41.939003,-39.939003,-34.939003,-30.939003,-27.939003,-26.939003,-25.939003,-23.939003,-23.939003,-23.939003,-19.939003,-18.939003,-18.939003,-19.939003,-19.939003,-18.939003,-18.939003,-18.939003,-19.939003,-21.939003,-22.939003,-27.939003,-30.939003,-29.939003,-34.939003,-39.939003,-20.939003,7.060997,46.060997,50.060997,50.060997,49.060997,15.060997,-29.939003,-28.939003,-29.939003,-29.939003,-26.939003,-20.939003,-10.939003,-2.939003,5.060997,14.060997,18.060997,19.060997,23.060997,27.060997,29.0609
97,32.060997,36.060997,40.060997,43.060997,47.060997,47.060997,45.060997,40.060997,41.060997,44.060997,39.060997,36.060997,34.060997,31.060997,27.060997,21.060997,21.060997,23.060997,22.060997,18.060997,12.060997,8.060997,5.060997,1.060997,-1.939003,-3.939003,-4.939003,-6.939003,-10.939003,-8.939003,-7.939003,-7.939003,-8.939003,-7.939003,-11.939003,-13.939003,-12.939003,-12.939003,-12.939003,-11.939003,-12.939003,-13.939003,-8.939003,0.06099701,14.060997,30.060997,41.060997,39.060997,43.060997,49.060997,49.060997,52.060997,56.060997,59.060997,63.060997,67.061,33.060997,-9.939003,15.060997,29.060997,32.060997,63.060997,84.061,78.061,68.061,58.060997,60.060997,59.060997,56.060997,56.060997,54.060997,50.060997,45.060997,41.060997,41.060997,37.060997,30.060997,29.060997,28.060997,26.060997,-6.939003,-49.939003,-63.939003,-73.939,-80.939,-88.939,-84.939,-57.939003,-48.939003,-46.939003,-23.939003,-18.939003,-34.939003,-48.939003,-62.939003,-73.939,-80.939,-86.939,-85.939,-80.939,-71.939,-51.939003,-32.939003,-23.939003,-23.939003,-28.939003,-34.939003,-35.939003,-35.939003,-38.939003,-41.939003,-44.939003,-46.939003,-48.939003,-51.939003,-56.939003,-61.939003,-64.939,-67.939,-69.939,-73.939,-77.939,-81.939,-86.939,-91.939,-92.939,-93.939,-96.939,-96.939,-94.939,-88.939,-77.939,-62.939003,-55.939003,-46.939003,-32.939003,-33.939003,-40.939003,-55.939003,-68.939,-78.939,-85.939,-87.939,-77.939,-68.939,-59.939003,-56.939003,-44.939003,-25.939003,-5.939003,4.060997,-9.939003,-21.939003,-29.939003,-23.939003,-17.939003,-11.939003,-9.939003,-8.939003,-8.939003,-7.939003,-7.939003,-6.939003,-6.939003,-8.939003,-11.939003,-13.939003,-15.939003,-14.939003,-13.939003,-18.939003,-23.939003,-25.939003,-29.939003,-33.939003,-37.939003,-38.939003,-37.939003,-40.939003,-44.939003,-48.939003,-41.939003,-41.939003,-57.939003,-60.939003,-56.939003,-70.939,-78.939,-81.939,-86.939,-91.939,-93.939,-93.939,-93.939,-95.939,-94.939,-91.939,-89.939,-83.939,-67.939,-8.939003,58.060997,47.060997,
36.060997,24.060997,-27.939003,-64.939,-58.939003,-57.939003,-56.939003,-52.939003,-49.939003,-46.939003,-42.939003,-38.939003,-35.939003,-31.939003,-30.939003,-29.939003,-26.939003,-22.939003,-22.939003,-22.939003,-20.939003,-19.939003,-17.939003,-17.939003,-17.939003,-17.939003,-18.939003,-19.939003,-20.939003,-22.939003,-25.939003,-27.939003,-31.939003,-35.939003,-39.939003,-42.939003,-43.939003,-43.939003,-44.939003,-49.939003,-53.939003,-57.939003,-43.939003,-29.939003,-22.939003,-33.939003,-49.939003,-21.939003,8.060997,40.060997,37.060997,32.060997,32.060997,33.060997,32.060997,29.060997,28.060997,26.060997,31.060997,29.060997,4.060997,-6.939003,-10.939003,5.060997,13.060997,14.060997,4.060997,-2.939003,-5.939003,-7.939003,-10.939003,-11.939003,-16.939003,-25.939003,-17.939003,-20.939003,-51.939003,-58.939003,-53.939003,-31.939003,-18.939003,-13.939003,-11.939003,-7.939003,-0.939003,-6.939003,-15.939003,6.060997,20.060997,29.060997,19.060997,3.060997,-22.939003,-21.939003,-6.939003,19.060997,24.060997,8.060997,6.060997,2.060997,-7.939003,-2.939003,8.060997,10.060997,22.060997,43.060997,48.060997,47.060997,40.060997,35.060997,31.060997,31.060997,30.060997,28.060997,26.060997,28.060997,36.060997,34.060997,27.060997,5.060997,-3.939003,1.060997,12.060997,14.060997,-8.939003,-13.939003,-13.939003,-32.939003,-30.939003,-9.939003,-19.939003,-33.939003,-47.939003,-62.939003,-75.939,-63.939003,-47.939003,-25.939003,-23.939003,-24.939003,-25.939003,-24.939003,-23.939003,-28.939003,-29.939003,-27.939003,-30.939003,-33.939003,-36.939003,-38.939003,-38.939003,-42.939003,-46.939003,-51.939003,-51.939003,-52.939003,-57.939003,-59.939003,-61.939003,-65.939,-70.939,-75.939,-77.939,-79.939,-81.939,-81.939,-80.939,-81.939,-82.939,-83.939,-80.939,-77.939,-76.939,-56.939003,-27.939003,2.060997,7.060997,-11.939003,-39.939003,-60.939003,-60.939003,-57.939003,-53.939003,-52.939003,-49.939003,-44.939003,-43.939003,-42.939003,-38.939003,-34.939003,-30.939003,-33.939003,-34.939003,-30.
939003,-32.939003,-33.939003,-34.939003,-37.939003,-39.939003,-40.939003,-42.939003,-45.939003,-47.939003,-50.939003,-51.939003,-50.939003,-48.939003,-44.939003,-42.939003,-40.939003,-39.939003,-40.939003,-49.939003,-52.939003,-49.939003,-15.939003,0.06099701,-2.939003,-16.939003,-23.939003,-9.939003,-5.939003,-7.939003,-6.939003,-13.939003,-29.939003,-31.939003,-28.939003,-20.939003,-19.939003,-20.939003,-26.939003,-28.939003,-28.939003,-29.939003,-30.939003,-34.939003,-36.939003,-39.939003,-43.939003,-45.939003,-48.939003,-44.939003,-42.939003,-45.939003,-40.939003,-31.939003,-31.939003,-31.939003,-32.939003,-30.939003,-29.939003,-33.939003,-38.939003,-41.939003,-34.939003,-29.939003,-23.939003,-21.939003,-19.939003,-19.939003,-17.939003,-15.939003,-19.939003,-23.939003,-31.939003,-22.939003,-17.939003,-23.939003,-30.939003,-37.939003,-32.939003,-32.939003,-36.939003,-36.939003,-34.939003,-31.939003,-32.939003,-34.939003,-34.939003,-32.939003,-26.939003,-31.939003,-34.939003,-35.939003,-29.939003,-24.939003,-28.939003,-29.939003,-25.939003,-22.939003,-19.939003,-21.939003,-17.939003,-12.939003,-27.939003,-46.939003,-66.939,-65.939,-64.939,-66.939,-68.939,-70.939,-72.939,-72.939,-71.939,-71.939,-72.939,-74.939,-74.939,-73.939,-75.939,-68.939,-54.939003,-66.939,-76.939,-78.939,-78.939,-78.939,-82.939,-82.939,-80.939,-81.939,-83.939,-85.939,-86.939,-88.939,-88.939,-89.939,-90.939,-93.939,-96.939,-95.939,-94.939,-87.939,-24.939003,17.060997,39.060997,1.060997,-43.939003,-89.939,-98.939,-91.939,-92.939,-92.939,-90.939,-89.939,-89.939,-88.939,-88.939,-87.939,-86.939,-86.939,-90.939,-88.939,-85.939,-83.939,-81.939,-78.939,-92.939,-97.939,-92.939,-89.939,-89.939,-94.939,-96.939,-94.939,-71.939,-53.939003,-38.939003,-43.939003,-45.939003,-36.939003,-39.939003,-48.939003,-50.939003,-48.939003,-42.939003,-51.939003,-57.939003,-52.939003,-55.939003,-61.939003,-63.939003,-63.939003,-60.939003,-58.939003,-59.939003,-68.939,-70.939,-69.939,-56.939003,-40.939003,-21.939003,-23.93
9003,-28.939003,-28.939003,-32.939003,-37.939003,-35.939003,-35.939003,-35.939003,-37.939003,-40.939003,-44.939003,-47.939003,-49.939003,-49.939003,-47.939003,-45.939003,-48.939003,-48.939003,-43.939003,-44.939003,-47.939003,-50.939003,-53.939003,-56.939003,-60.939003,-63.939003,-66.939,-63.939003,-58.939003,-67.939,-73.939,-75.939,-80.939,-84.939,0.06099701,14.060997,34.060997,34.060997,38.060997,44.060997,52.060997,59.060997,64.061,66.061,67.061,67.061,67.061,67.061,79.061,67.061,-3.939003,-58.939003,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-98.939,-97.939,-98.939,-96.939,-94.939,-91.939,-89.939,-86.939,-83.939,-77.939,-68.939,-64.939,-59.939003,-50.939003,-47.939003,-45.939003,-42.939003,-36.939003,-29.939003,-22.939003,-18.939003,-24.939003,-18.939003,-8.939003,-1.939003,3.060997,7.060997,12.060997,15.060997,15.060997,15.060997,14.060997,12.060997,8.060997,3.060997,-3.939003,-8.939003,-10.939003,-18.939003,-28.939003,-33.939003,-42.939003,-55.939003,-62.939003,-67.939,-76.939,-83.939,-90.939,-93.939,-96.939,-102.939,-101.939,-101.939,-102.939,-57.939003,8.060997,78.061,103.061,83.061,92.061,66.061,-45.939003,-88.939,-99.939,-96.939,-95.939,-96.939,-97.939,-98.939,-98.939,-98.939,-97.939,-97.939,-94.939,-90.939,-87.939,-85.939,-83.939,-79.939,-74.939,-70.939,-61.939003,-50.939003,-42.939003,-36.939003,-30.939003,-22.939003,-13.939003,-9.939003,-3.939003,5.060997,8.060997,13.060997,16.060997,18.060997,19.060997,18.060997,18.060997,20.060997,14.060997,7.060997,3.060997,-6.939003,-16.939003,-17.939003,-23.939003,-32.939003,-46.939003,-57.939003,-64.939,-65.939,-65.939,-68.939,-73.939,-80.939,-85.939,-85.939,-71.939,-75.939,-80.939,-34.939003,21.060997,87.061,87.061,83.061,84.061,57.060997,23.060997,33.060997,42.060997,48.060997,54.060997,62.060997,79.061,89.061,96.061,94.061,92.061,91.061,89.061,83.061,70.061,60.060997,50.060997,41.060997,31.060997,21.060997,11.060997,1.060997,-11.939003,-16.939003,-19.939003,-28.939003,-37.93
9003,-45.939003,-44.939003,-43.939003,-46.939003,-42.939003,-36.939003,-31.939003,-25.939003,-17.939003,-12.939003,-8.939003,-4.939003,0.06099701,7.060997,16.060997,21.060997,20.060997,26.060997,30.060997,30.060997,29.060997,27.060997,15.060997,4.060997,-4.939003,-6.939003,-8.939003,-8.939003,-9.939003,-10.939003,-3.939003,14.060997,42.060997,70.061,88.061,77.061,75.061,78.061,78.061,78.061,80.061,75.061,72.061,77.061,40.060997,-6.939003,16.060997,25.060997,22.060997,47.060997,65.061,61.060997,48.060997,34.060997,31.060997,29.060997,26.060997,23.060997,18.060997,6.060997,0.06099701,-5.939003,-11.939003,-15.939003,-17.939003,-20.939003,-22.939003,-21.939003,-16.939003,-12.939003,-20.939003,-30.939003,-44.939003,-60.939003,-75.939,-90.939,-97.939,-102.939,-89.939,-88.939,-98.939,-101.939,-101.939,-92.939,-84.939,-78.939,-92.939,-90.939,-71.939,-61.939003,-58.939003,-72.939,-83.939,-95.939,-100.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-98.939,-97.939,-97.939,-97.939,-95.939,-94.939,-93.939,-89.939,-84.939,-76.939,-63.939003,-68.939,-91.939,-98.939,-103.939,-102.939,-101.939,-100.939,-102.939,-103.939,-103.939,-87.939,-68.939,-45.939003,-23.939003,-4.939003,0.06099701,-4.939003,-20.939003,-37.939003,-43.939003,-23.939003,-0.939003,22.060997,11.060997,10.060997,14.060997,8.060997,4.060997,5.060997,0.06099701,-7.939003,-20.939003,-27.939003,-30.939003,-41.939003,-52.939003,-64.939,-65.939,-65.939,-76.939,-85.939,-94.939,-99.939,-103.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-81.939,-73.939,-93.939,-95.939,-89.939,-93.939,-95.939,-97.939,-97.939,-95.939,-86.939,-85.939,-87.939,-82.939,-76.939,-70.939,-65.939,-57.939003,-39.939003,-19.939003,-2.939003,-17.939003,-24.939003,-24.939003,-21.939003,-14.939003,-2.939003,1.060997,1.060997,7.060997,11.060997,12.060997,13.060997,13.060997,14.060997,15.060997,13.060997,10.060997,7.060997,4.060997,0.06099701,-4.939003,-8.939003,-17.939003,-
24.939003,-27.939003,-33.939003,-42.939003,-50.939003,-58.939003,-63.939003,-69.939,-77.939,-84.939,-91.939,-100.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-62.939003,-31.939003,-22.939003,-38.939003,-61.939003,-29.939003,6.060997,44.060997,34.060997,21.060997,15.060997,10.060997,3.060997,-3.939003,-10.939003,-16.939003,-16.939003,-16.939003,-21.939003,-26.939003,-31.939003,-35.939003,-40.939003,-47.939003,-49.939003,-49.939003,-48.939003,-49.939003,-53.939003,-48.939003,-43.939003,-39.939003,-34.939003,-36.939003,-58.939003,-59.939003,-51.939003,-28.939003,-16.939003,-16.939003,-9.939003,-7.939003,-18.939003,-28.939003,-33.939003,-11.939003,8.060997,26.060997,21.060997,13.060997,0.06099701,-1.939003,0.06099701,8.060997,5.060997,-10.939003,-16.939003,-17.939003,-14.939003,-7.939003,2.060997,5.060997,20.060997,46.060997,45.060997,41.060997,38.060997,36.060997,34.060997,32.060997,30.060997,28.060997,23.060997,21.060997,28.060997,29.060997,25.060997,-18.939003,-30.939003,-13.939003,14.060997,25.060997,-9.939003,-20.939003,-23.939003,-36.939003,-26.939003,7.060997,-13.939003,-35.939003,-49.939003,-65.939,-78.939,-61.939003,-41.939003,-18.939003,-18.939003,-21.939003,-29.939003,-35.939003,-43.939003,-59.939003,-68.939,-69.939,-75.939,-80.939,-86.939,-91.939,-97.939,-98.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-96.939,-93.939,-84.939,-76.939,-70.939,-66.939,-61.939003,-54.939003,-47.939003,-41.939003,-40.939003,-40.939003,-40.939003,-48.939003,-53.939003,-54.939003,-36.939003,-19.939003,-12.939003,-9.939003,-10.939003,-15.939003,-17.939003,-18.939003,-14.939003,-14.939003,-21.939003,-30.939003,-36.939003,-34.939003,-37.939003,-46.939003,-54.939003,-61.939003,-65.939,-73.939,-82.939,-87.939,-90.939,-93.939,-97.939,-99.939,-100.939,-87.939,-67.939,-51.939003,-43.939003,-43.939003,-41.939003,-42.939003,-50.939003,-55.939003,-55.939003,-20.939003,-6.939003,-13.939003,-28.939003,-36.939003,-29.939003
,-31.939003,-38.939003,-42.939003,-45.939003,-49.939003,-52.939003,-58.939003,-71.939,-72.939,-68.939,-71.939,-68.939,-61.939003,-54.939003,-50.939003,-52.939003,-51.939003,-49.939003,-42.939003,-38.939003,-40.939003,-32.939003,-31.939003,-47.939003,-48.939003,-42.939003,-41.939003,-43.939003,-50.939003,-47.939003,-44.939003,-43.939003,-42.939003,-40.939003,-36.939003,-30.939003,-23.939003,-21.939003,-19.939003,-20.939003,-17.939003,-13.939003,-18.939003,-24.939003,-33.939003,-11.939003,5.060997,1.060997,-8.939003,-20.939003,-20.939003,-22.939003,-24.939003,-25.939003,-27.939003,-28.939003,-33.939003,-39.939003,-35.939003,-36.939003,-42.939003,-42.939003,-40.939003,-38.939003,-43.939003,-51.939003,-55.939003,-56.939003,-54.939003,-54.939003,-54.939003,-53.939003,-55.939003,-56.939003,-45.939003,-37.939003,-29.939003,-22.939003,-18.939003,-23.939003,-27.939003,-32.939003,-36.939003,-36.939003,-33.939003,-34.939003,-36.939003,-38.939003,-38.939003,-37.939003,-40.939003,-39.939003,-36.939003,-39.939003,-43.939003,-47.939003,-47.939003,-47.939003,-55.939003,-56.939003,-52.939003,-55.939003,-58.939003,-60.939003,-62.939003,-67.939,-66.939,-68.939,-71.939,-78.939,-83.939,-82.939,-81.939,-76.939,-29.939003,1.060997,14.060997,-15.939003,-50.939003,-87.939,-98.939,-98.939,-99.939,-98.939,-97.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-101.939,-91.939,-84.939,-86.939,-83.939,-79.939,-91.939,-96.939,-93.939,-91.939,-91.939,-96.939,-98.939,-97.939,-64.939,-38.939003,-18.939003,-35.939003,-49.939003,-48.939003,-55.939003,-65.939,-61.939003,-57.939003,-52.939003,-55.939003,-56.939003,-50.939003,-51.939003,-56.939003,-55.939003,-52.939003,-47.939003,-50.939003,-54.939003,-61.939003,-67.939,-71.939,-50.939003,-35.939003,-25.939003,-24.939003,-26.939003,-29.939003,-36.939003,-43.939003,-34.939003,-29.939003,-26.939003,-30.939003,-33.939003,-34.939003,-40.939003,-47.939003,-47.939003,-44.939003,-39.939003,-43.939003,-45.939003,-40.939003,-42.939003,-45.939003,-46.939
003,-50.939003,-57.939003,-59.939003,-63.939003,-71.939,-66.939,-58.939003,-65.939,-70.939,-73.939,-78.939,-83.939,16.060997,39.060997,68.061,65.061,65.061,67.061,70.061,72.061,73.061,73.061,72.061,70.061,72.061,75.061,87.061,76.061,9.060997,-51.939003,-102.939,-93.939,-89.939,-90.939,-89.939,-86.939,-83.939,-80.939,-75.939,-68.939,-62.939003,-57.939003,-54.939003,-51.939003,-44.939003,-36.939003,-27.939003,-26.939003,-21.939003,-13.939003,-12.939003,-10.939003,-5.939003,-5.939003,-6.939003,-3.939003,-8.939003,-20.939003,-18.939003,-17.939003,-19.939003,-22.939003,-25.939003,-26.939003,-29.939003,-33.939003,-36.939003,-39.939003,-42.939003,-47.939003,-53.939003,-58.939003,-62.939003,-67.939,-70.939,-71.939,-71.939,-73.939,-78.939,-80.939,-82.939,-86.939,-87.939,-89.939,-93.939,-95.939,-96.939,-99.939,-101.939,-103.939,-100.939,-100.939,-102.939,-70.939,-19.939003,60.060997,88.061,62.060997,69.061,49.060997,-38.939003,-75.939,-88.939,-77.939,-68.939,-61.939003,-60.939003,-57.939003,-51.939003,-46.939003,-40.939003,-33.939003,-28.939003,-26.939003,-22.939003,-16.939003,-10.939003,-6.939003,-5.939003,-5.939003,-6.939003,-7.939003,-6.939003,-5.939003,-5.939003,-10.939003,-18.939003,-21.939003,-23.939003,-23.939003,-27.939003,-31.939003,-40.939003,-44.939003,-48.939003,-49.939003,-51.939003,-54.939003,-59.939003,-61.939003,-55.939003,-51.939003,-49.939003,-46.939003,-43.939003,-37.939003,-33.939003,-28.939003,-22.939003,-12.939003,-2.939003,4.060997,10.060997,15.060997,19.060997,24.060997,35.060997,37.060997,38.060997,45.060997,51.060997,56.060997,57.060997,58.060997,59.060997,46.060997,29.060997,30.060997,31.060997,28.060997,22.060997,15.060997,9.060997,7.060997,7.060997,2.060997,-0.939003,-0.939003,-1.939003,-4.939003,-8.939003,-10.939003,-11.939003,-11.939003,-9.939003,-6.939003,-5.939003,-2.939003,3.060997,4.060997,5.060997,11.060997,15.060997,15.060997,12.060997,11.060997,13.060997,16.060997,17.060997,15.060997,13.060997,10.060997,4.060997,0.06099701,-4.939003,-9.93
9003,-14.939003,-18.939003,-27.939003,-39.939003,-44.939003,-49.939003,-55.939003,-58.939003,-59.939003,-63.939003,-67.939,-70.939,-71.939,-67.939,-55.939003,-40.939003,-23.939003,-14.939003,3.060997,29.060997,55.060997,69.061,56.060997,49.060997,47.060997,43.060997,40.060997,37.060997,31.060997,26.060997,24.060997,7.060997,-11.939003,-6.939003,-5.939003,-8.939003,-4.939003,-1.939003,-1.939003,-8.939003,-16.939003,-15.939003,-13.939003,-11.939003,-14.939003,-16.939003,-18.939003,-14.939003,-10.939003,-13.939003,-12.939003,-10.939003,-4.939003,-1.939003,-4.939003,-1.939003,2.060997,14.060997,21.060997,22.060997,16.060997,-8.939003,-77.939,-98.939,-99.939,-94.939,-95.939,-101.939,-102.939,-101.939,-95.939,-93.939,-95.939,-97.939,-97.939,-92.939,-89.939,-88.939,-93.939,-96.939,-100.939,-100.939,-101.939,-103.939,-101.939,-98.939,-97.939,-95.939,-93.939,-89.939,-85.939,-81.939,-73.939,-66.939,-63.939003,-59.939003,-55.939003,-50.939003,-43.939003,-34.939003,-29.939003,-26.939003,-24.939003,-21.939003,-16.939003,-6.939003,-25.939003,-72.939,-91.939,-101.939,-93.939,-94.939,-99.939,-101.939,-102.939,-103.939,-90.939,-66.939,-23.939003,-14.939003,-17.939003,-12.939003,-4.939003,6.060997,7.060997,7.060997,3.060997,-25.939003,-61.939003,-63.939003,-64.939,-64.939,-66.939,-67.939,-67.939,-68.939,-71.939,-75.939,-78.939,-79.939,-82.939,-86.939,-90.939,-90.939,-90.939,-94.939,-97.939,-100.939,-102.939,-101.939,-95.939,-91.939,-89.939,-86.939,-80.939,-73.939,-58.939003,-51.939003,-58.939003,-55.939003,-47.939003,-44.939003,-40.939003,-37.939003,-37.939003,-35.939003,-26.939003,-20.939003,-16.939003,-13.939003,-11.939003,-9.939003,-8.939003,-6.939003,-4.939003,-14.939003,-26.939003,-27.939003,-23.939003,-15.939003,-10.939003,-7.939003,-10.939003,-12.939003,-15.939003,-17.939003,-21.939003,-26.939003,-31.939003,-37.939003,-41.939003,-46.939003,-53.939003,-58.939003,-63.939003,-67.939,-68.939,-70.939,-71.939,-74.939,-77.939,-78.939,-80.939,-83.939,-85.939,-88.939,-90.939,-91.939,-9
2.939,-96.939,-96.939,-92.939,-92.939,-91.939,-89.939,-85.939,-80.939,-74.939,-69.939,-62.939003,-43.939003,-28.939003,-28.939003,-29.939003,-30.939003,-24.939003,-18.939003,-14.939003,-19.939003,-25.939003,-30.939003,-33.939003,-35.939003,-36.939003,-34.939003,-30.939003,-33.939003,-32.939003,-24.939003,-26.939003,-31.939003,-28.939003,-25.939003,-24.939003,-22.939003,-20.939003,-17.939003,-13.939003,-9.939003,-4.939003,1.060997,8.060997,17.060997,8.060997,-45.939003,-56.939003,-48.939003,-27.939003,-15.939003,-13.939003,-9.939003,-6.939003,-6.939003,-14.939003,-24.939003,-36.939003,-38.939003,-31.939003,-6.939003,10.060997,0.06099701,5.060997,14.060997,0.06099701,-7.939003,-9.939003,17.060997,37.060997,28.060997,27.060997,28.060997,28.060997,33.060997,41.060997,36.060997,32.060997,39.060997,37.060997,33.060997,33.060997,24.060997,4.060997,16.060997,23.060997,13.060997,11.060997,10.060997,-23.939003,-22.939003,12.060997,25.060997,23.060997,-4.939003,-25.939003,-41.939003,-37.939003,-28.939003,-13.939003,-29.939003,-43.939003,-48.939003,-63.939003,-79.939,-78.939,-76.939,-74.939,-75.939,-76.939,-78.939,-80.939,-83.939,-88.939,-91.939,-92.939,-93.939,-91.939,-85.939,-86.939,-88.939,-80.939,-75.939,-73.939,-69.939,-64.939,-62.939003,-57.939003,-51.939003,-46.939003,-42.939003,-38.939003,-37.939003,-33.939003,-25.939003,-21.939003,-19.939003,-20.939003,-20.939003,-18.939003,-15.939003,-15.939003,-20.939003,-27.939003,-33.939003,-31.939003,-26.939003,-19.939003,-32.939003,-44.939003,-49.939003,-51.939003,-51.939003,-62.939003,-68.939,-72.939,-72.939,-73.939,-76.939,-79.939,-81.939,-79.939,-80.939,-84.939,-87.939,-89.939,-90.939,-92.939,-94.939,-95.939,-92.939,-87.939,-82.939,-78.939,-73.939,-64.939,-52.939003,-44.939003,-40.939003,-42.939003,-42.939003,-41.939003,-40.939003,-40.939003,-40.939003,-43.939003,-44.939003,-44.939003,-41.939003,-39.939003,-43.939003,-43.939003,-43.939003,-41.939003,-39.939003,-36.939003,-35.939003,-33.939003,-33.939003,-31.939003,-27.939003,-
19.939003,-13.939003,-9.939003,-7.939003,-5.939003,-3.939003,3.060997,10.060997,0.06099701,-1.939003,0.06099701,-4.939003,-18.939003,-52.939003,-58.939003,-53.939003,-51.939003,-48.939003,-47.939003,-47.939003,-46.939003,-47.939003,-44.939003,-39.939003,-31.939003,-25.939003,-18.939003,-21.939003,-22.939003,-20.939003,-18.939003,-16.939003,-15.939003,-22.939003,-38.939003,-7.939003,24.060997,46.060997,39.060997,22.060997,27.060997,25.060997,18.060997,21.060997,23.060997,26.060997,25.060997,21.060997,21.060997,22.060997,20.060997,-4.939003,-21.939003,-13.939003,2.060997,18.060997,11.060997,8.060997,9.060997,10.060997,8.060997,4.060997,6.060997,7.060997,-25.939003,-54.939003,-76.939,-73.939,-68.939,-69.939,-68.939,-67.939,-68.939,-67.939,-66.939,-63.939003,-61.939003,-60.939003,-59.939003,-59.939003,-57.939003,-52.939003,-45.939003,-49.939003,-52.939003,-49.939003,-51.939003,-54.939003,-52.939003,-50.939003,-48.939003,-48.939003,-48.939003,-48.939003,-45.939003,-42.939003,-45.939003,-49.939003,-52.939003,-50.939003,-48.939003,-49.939003,-49.939003,-49.939003,-45.939003,-43.939003,-42.939003,-43.939003,-44.939003,-48.939003,-52.939003,-55.939003,-56.939003,-57.939003,-59.939003,-57.939003,-57.939003,-60.939003,-62.939003,-62.939003,-57.939003,-64.939,-85.939,-92.939,-94.939,-86.939,-80.939,-77.939,-93.939,-97.939,-88.939,-92.939,-96.939,-100.939,-101.939,-98.939,-65.939,-40.939003,-23.939003,-45.939003,-63.939003,-63.939003,-64.939,-66.939,-66.939,-68.939,-72.939,-74.939,-71.939,-61.939003,-55.939003,-52.939003,-51.939003,-51.939003,-52.939003,-46.939003,-45.939003,-52.939003,-60.939003,-67.939,-62.939003,-50.939003,-32.939003,-29.939003,-29.939003,-36.939003,-36.939003,-33.939003,-29.939003,-28.939003,-30.939003,-35.939003,-39.939003,-39.939003,-44.939003,-50.939003,-48.939003,-45.939003,-42.939003,-48.939003,-53.939003,-50.939003,-51.939003,-52.939003,-51.939003,-51.939003,-54.939003,-57.939003,-62.939003,-68.939,-68.939,-65.939,-64.939,-66.939,-70.939,-69.939,-70.93
9,21.060997,48.060997,81.061,80.061,77.061,72.061,72.061,71.061,68.061,65.061,60.060997,56.060997,53.060997,51.060997,59.060997,52.060997,5.060997,-37.939003,-72.939,-59.939003,-52.939003,-51.939003,-49.939003,-47.939003,-45.939003,-41.939003,-36.939003,-31.939003,-26.939003,-23.939003,-21.939003,-19.939003,-15.939003,-11.939003,-8.939003,-9.939003,-8.939003,-7.939003,-8.939003,-9.939003,-11.939003,-11.939003,-13.939003,-17.939003,-21.939003,-22.939003,-22.939003,-26.939003,-40.939003,-49.939003,-55.939003,-58.939003,-63.939003,-69.939,-74.939,-78.939,-81.939,-86.939,-92.939,-96.939,-100.939,-103.939,-103.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-98.939,-95.939,-93.939,-94.939,-92.939,-88.939,-85.939,-83.939,-83.939,-79.939,-76.939,-74.939,-56.939003,-28.939003,25.060997,41.060997,19.060997,21.060997,12.060997,-28.939003,-45.939003,-48.939003,-36.939003,-28.939003,-23.939003,-23.939003,-23.939003,-20.939003,-17.939003,-13.939003,-10.939003,-9.939003,-11.939003,-10.939003,-8.939003,-6.939003,-7.939003,-9.939003,-10.939003,-14.939003,-19.939003,-21.939003,-21.939003,-21.939003,-26.939003,-33.939003,-37.939003,-39.939003,-39.939003,-38.939003,-38.939003,-43.939003,-45.939003,-46.939003,-40.939003,-38.939003,-39.939003,-41.939003,-39.939003,-27.939003,-20.939003,-14.939003,-10.939003,-5.939003,2.060997,10.060997,17.060997,23.060997,29.060997,36.060997,40.060997,43.060997,48.060997,51.060997,52.060997,53.060997,54.060997,55.060997,45.060997,33.060997,17.060997,17.060997,17.060997,16.060997,13.060997,10.060997,10.060997,8.060997,4.060997,0.06099701,-5.939003,-15.939003,-17.939003,-17.939003,-20.939003,-21.939003,-18.939003,-19.939003,-20.939003,-19.939003,-15.939003,-11.939003,-9.939003,-8.939003,-7.939003,-6.939003,-4.939003,3.060997,2.060997,0.06099701,8.060997,12.060997,11.060997,3.060997,-2.939003,-1.939003,-1.939003,-1.939003,-5.939003,-9.939003,-15.939003,-23.939003,-29.939003,-35.939003,-42.939003,-49.939003,-57.939003,-66.939,-77.939,-82.939,-87.9
39,-90.939,-90.939,-89.939,-86.939,-84.939,-81.939,-78.939,-71.939,-57.939003,-42.939003,-27.939003,-17.939003,-6.939003,6.060997,20.060997,27.060997,17.060997,12.060997,9.060997,5.060997,2.060997,-0.939003,-3.939003,-5.939003,-7.939003,-10.939003,-12.939003,-13.939003,-14.939003,-16.939003,-14.939003,-12.939003,-10.939003,-12.939003,-16.939003,-10.939003,-5.939003,-2.939003,-4.939003,-3.939003,1.060997,8.060997,14.060997,13.060997,16.060997,19.060997,29.060997,33.060997,23.060997,19.060997,21.060997,47.060997,61.060997,61.060997,68.061,46.060997,-47.939003,-86.939,-99.939,-93.939,-92.939,-98.939,-100.939,-100.939,-95.939,-93.939,-92.939,-86.939,-84.939,-85.939,-81.939,-78.939,-80.939,-77.939,-73.939,-71.939,-68.939,-66.939,-62.939003,-58.939003,-58.939003,-57.939003,-54.939003,-49.939003,-44.939003,-39.939003,-31.939003,-24.939003,-24.939003,-22.939003,-19.939003,-16.939003,-11.939003,-4.939003,-1.939003,-0.939003,-1.939003,-2.939003,-3.939003,-0.939003,-21.939003,-65.939,-88.939,-99.939,-88.939,-88.939,-92.939,-98.939,-101.939,-100.939,-93.939,-77.939,-40.939003,-38.939003,-47.939003,-29.939003,0.06099701,43.060997,47.060997,40.060997,23.060997,-35.939003,-103.939,-101.939,-100.939,-102.939,-100.939,-98.939,-95.939,-92.939,-90.939,-88.939,-86.939,-85.939,-83.939,-81.939,-79.939,-76.939,-73.939,-70.939,-69.939,-68.939,-66.939,-62.939003,-52.939003,-49.939003,-50.939003,-45.939003,-39.939003,-32.939003,-29.939003,-27.939003,-24.939003,-21.939003,-18.939003,-15.939003,-10.939003,-6.939003,-10.939003,-12.939003,-8.939003,-4.939003,-2.939003,-5.939003,-7.939003,-7.939003,-10.939003,-13.939003,-15.939003,-8.939003,2.060997,5.060997,8.060997,11.060997,-13.939003,-35.939003,-42.939003,-45.939003,-48.939003,-53.939003,-58.939003,-63.939003,-68.939,-73.939,-78.939,-84.939,-90.939,-93.939,-95.939,-96.939,-94.939,-92.939,-90.939,-87.939,-84.939,-84.939,-82.939,-79.939,-78.939,-76.939,-74.939,-71.939,-68.939,-68.939,-64.939,-54.939003,-52.939003,-52.939003,-51.939003,-47.93900
3,-41.939003,-34.939003,-30.939003,-26.939003,-24.939003,-24.939003,-26.939003,-20.939003,-12.939003,-18.939003,-24.939003,-32.939003,-30.939003,-29.939003,-32.939003,-32.939003,-29.939003,-25.939003,-19.939003,-13.939003,-12.939003,-11.939003,-11.939003,-13.939003,-13.939003,-3.939003,6.060997,16.060997,15.060997,15.060997,15.060997,19.060997,24.060997,24.060997,26.060997,32.060997,41.060997,28.060997,-38.939003,-54.939003,-45.939003,-26.939003,-14.939003,-9.939003,-6.939003,-2.939003,1.060997,-1.939003,-8.939003,-26.939003,-36.939003,-37.939003,-8.939003,10.060997,1.060997,3.060997,7.060997,-1.939003,-1.939003,11.060997,36.060997,52.060997,44.060997,40.060997,39.060997,38.060997,37.060997,39.060997,33.060997,30.060997,34.060997,35.060997,34.060997,38.060997,22.060997,-16.939003,7.060997,23.060997,6.060997,-2.939003,-9.939003,-17.939003,-4.939003,27.060997,27.060997,16.060997,-9.939003,-30.939003,-46.939003,-41.939003,-35.939003,-26.939003,-36.939003,-46.939003,-50.939003,-63.939003,-80.939,-85.939,-87.939,-86.939,-84.939,-81.939,-79.939,-78.939,-77.939,-75.939,-74.939,-74.939,-71.939,-65.939,-56.939003,-55.939003,-57.939003,-49.939003,-45.939003,-42.939003,-39.939003,-37.939003,-34.939003,-30.939003,-26.939003,-24.939003,-23.939003,-21.939003,-23.939003,-23.939003,-20.939003,-21.939003,-22.939003,-25.939003,-28.939003,-32.939003,-31.939003,-34.939003,-39.939003,-33.939003,-20.939003,1.060997,9.060997,5.060997,-36.939003,-71.939,-79.939,-81.939,-79.939,-85.939,-88.939,-90.939,-89.939,-88.939,-86.939,-85.939,-85.939,-81.939,-78.939,-77.939,-77.939,-76.939,-72.939,-72.939,-72.939,-70.939,-65.939,-59.939003,-55.939003,-52.939003,-47.939003,-44.939003,-40.939003,-38.939003,-37.939003,-41.939003,-41.939003,-40.939003,-38.939003,-36.939003,-35.939003,-38.939003,-40.939003,-40.939003,-37.939003,-33.939003,-31.939003,-29.939003,-27.939003,-22.939003,-22.939003,-25.939003,-22.939003,-15.939003,-5.939003,-3.939003,-5.939003,0.06099701,3.060997,3.060997,2.060997,1.060997,2.06
0997,7.060997,11.060997,0.06099701,-5.939003,-4.939003,-12.939003,-24.939003,-47.939003,-50.939003,-45.939003,-43.939003,-40.939003,-37.939003,-36.939003,-36.939003,-42.939003,-42.939003,-38.939003,-29.939003,-22.939003,-16.939003,-19.939003,-21.939003,-20.939003,-18.939003,-17.939003,-15.939003,-21.939003,-36.939003,-13.939003,10.060997,30.060997,25.060997,13.060997,18.060997,17.060997,12.060997,14.060997,17.060997,23.060997,26.060997,25.060997,23.060997,24.060997,27.060997,-0.939003,-20.939003,-9.939003,11.060997,32.060997,27.060997,24.060997,24.060997,25.060997,24.060997,21.060997,27.060997,29.060997,-21.939003,-64.939,-97.939,-96.939,-91.939,-92.939,-90.939,-88.939,-89.939,-89.939,-89.939,-86.939,-83.939,-81.939,-80.939,-80.939,-77.939,-69.939,-57.939003,-65.939,-71.939,-67.939,-69.939,-72.939,-67.939,-64.939,-64.939,-63.939003,-62.939003,-61.939003,-57.939003,-53.939003,-56.939003,-59.939003,-62.939003,-57.939003,-53.939003,-54.939003,-53.939003,-49.939003,-35.939003,-29.939003,-31.939003,-41.939003,-48.939003,-49.939003,-51.939003,-54.939003,-53.939003,-54.939003,-56.939003,-53.939003,-52.939003,-54.939003,-56.939003,-56.939003,-49.939003,-57.939003,-81.939,-95.939,-99.939,-86.939,-77.939,-74.939,-91.939,-95.939,-87.939,-93.939,-98.939,-99.939,-95.939,-87.939,-66.939,-53.939003,-47.939003,-59.939003,-68.939,-67.939,-68.939,-68.939,-65.939,-69.939,-79.939,-80.939,-76.939,-63.939003,-56.939003,-54.939003,-53.939003,-51.939003,-46.939003,-41.939003,-40.939003,-47.939003,-51.939003,-53.939003,-63.939003,-56.939003,-36.939003,-28.939003,-26.939003,-35.939003,-34.939003,-30.939003,-28.939003,-28.939003,-29.939003,-34.939003,-38.939003,-41.939003,-46.939003,-51.939003,-49.939003,-47.939003,-48.939003,-52.939003,-55.939003,-53.939003,-53.939003,-55.939003,-54.939003,-53.939003,-53.939003,-56.939003,-61.939003,-69.939,-71.939,-69.939,-67.939,-67.939,-70.939,-68.939,-66.939,16.060997,41.060997,74.061,77.061,73.061,61.060997,57.060997,54.060997,48.060997,40.060997,31.060
997,23.060997,11.060997,-3.939003,-3.939003,-5.939003,-17.939003,-17.939003,-11.939003,-0.939003,8.060997,15.060997,17.060997,16.060997,11.060997,13.060997,16.060997,12.060997,8.060997,2.060997,1.060997,-0.939003,-4.939003,-15.939003,-27.939003,-32.939003,-39.939003,-49.939003,-52.939003,-56.939003,-66.939,-66.939,-66.939,-83.939,-75.939,-37.939003,-35.939003,-47.939003,-86.939,-98.939,-99.939,-99.939,-99.939,-100.939,-100.939,-101.939,-101.939,-101.939,-102.939,-102.939,-103.939,-102.939,-103.939,-103.939,-103.939,-100.939,-96.939,-96.939,-95.939,-93.939,-85.939,-78.939,-78.939,-74.939,-67.939,-52.939003,-44.939003,-42.939003,-36.939003,-29.939003,-19.939003,-16.939003,-18.939003,-27.939003,-35.939003,-43.939003,-50.939003,-45.939003,-17.939003,3.060997,20.060997,25.060997,24.060997,18.060997,11.060997,3.060997,-4.939003,-10.939003,-16.939003,-29.939003,-38.939003,-45.939003,-51.939003,-59.939003,-72.939,-80.939,-84.939,-85.939,-87.939,-86.939,-87.939,-86.939,-80.939,-70.939,-59.939003,-56.939003,-50.939003,-42.939003,-24.939003,-7.939003,6.060997,16.060997,24.060997,44.060997,58.060997,66.061,69.061,74.061,85.061,88.061,88.061,91.061,90.061,88.061,85.061,80.061,72.061,62.060997,51.060997,37.060997,26.060997,18.060997,10.060997,-0.939003,-15.939003,-24.939003,-29.939003,-33.939003,-33.939003,-28.939003,-32.939003,-37.939003,-44.939003,-40.939003,-32.939003,-27.939003,-24.939003,-22.939003,-12.939003,-1.939003,5.060997,14.060997,21.060997,24.060997,29.060997,36.060997,35.060997,34.060997,37.060997,43.060997,49.060997,45.060997,35.060997,17.060997,6.060997,-3.939003,-11.939003,-23.939003,-34.939003,-38.939003,-46.939003,-58.939003,-71.939,-84.939,-92.939,-94.939,-94.939,-94.939,-94.939,-95.939,-96.939,-96.939,-97.939,-97.939,-98.939,-99.939,-97.939,-92.939,-88.939,-83.939,-73.939,-66.939,-61.939003,-52.939003,-45.939003,-36.939003,-27.939003,-18.939003,-12.939003,-15.939003,-20.939003,-12.939003,-15.939003,-26.939003,-34.939003,-38.939003,-37.939003,-36.939003,-35.93
9003,-36.939003,-36.939003,-35.939003,-29.939003,-23.939003,-20.939003,-14.939003,-7.939003,-5.939003,-2.939003,-0.939003,17.060997,31.060997,35.060997,35.060997,34.060997,47.060997,53.060997,52.060997,54.060997,58.060997,66.061,69.061,70.061,69.061,71.061,72.061,82.061,82.061,62.060997,48.060997,43.060997,79.061,89.061,72.061,94.061,87.061,-2.939003,-61.939003,-101.939,-85.939,-80.939,-88.939,-95.939,-98.939,-92.939,-82.939,-70.939,-59.939003,-51.939003,-49.939003,-37.939003,-29.939003,-34.939003,-26.939003,-14.939003,-12.939003,-5.939003,6.060997,13.060997,17.060997,14.060997,13.060997,14.060997,19.060997,22.060997,23.060997,25.060997,25.060997,17.060997,13.060997,11.060997,5.060997,-0.939003,-7.939003,-10.939003,-14.939003,-22.939003,-29.939003,-38.939003,-45.939003,-56.939003,-71.939,-89.939,-98.939,-87.939,-81.939,-79.939,-95.939,-100.939,-94.939,-98.939,-101.939,-97.939,-97.939,-93.939,-50.939003,11.060997,92.061,80.061,57.060997,36.060997,-29.939003,-103.939,-100.939,-99.939,-100.939,-95.939,-88.939,-79.939,-71.939,-65.939,-58.939003,-53.939003,-49.939003,-43.939003,-37.939003,-31.939003,-22.939003,-12.939003,-5.939003,-1.939003,2.060997,8.060997,14.060997,25.060997,23.060997,16.060997,20.060997,22.060997,21.060997,6.060997,-0.939003,10.060997,5.060997,-3.939003,-8.939003,-7.939003,-4.939003,-15.939003,-26.939003,-33.939003,-39.939003,-47.939003,-58.939003,-63.939003,-66.939,-73.939,-77.939,-71.939,-2.939003,84.061,80.061,71.061,57.060997,-29.939003,-97.939,-97.939,-98.939,-98.939,-98.939,-99.939,-98.939,-96.939,-95.939,-97.939,-98.939,-97.939,-93.939,-88.939,-81.939,-76.939,-70.939,-65.939,-56.939003,-47.939003,-46.939003,-41.939003,-31.939003,-27.939003,-22.939003,-17.939003,-10.939003,-3.939003,-1.939003,4.060997,13.060997,16.060997,16.060997,12.060997,12.060997,14.060997,18.060997,15.060997,9.060997,-5.939003,-17.939003,-16.939003,-12.939003,-6.939003,-9.939003,-10.939003,-11.939003,0.06099701,9.060997,8.060997,12.060997,19.060997,29.060997,34.060997,33.0
60997,46.060997,47.060997,17.060997,12.060997,21.060997,38.060997,56.060997,75.061,66.061,57.060997,52.060997,50.060997,48.060997,38.060997,32.060997,32.060997,37.060997,23.060997,-39.939003,-53.939003,-44.939003,-25.939003,-12.939003,-5.939003,-0.939003,2.060997,5.060997,9.060997,15.060997,18.060997,16.060997,9.060997,14.060997,14.060997,3.060997,-9.939003,-19.939003,1.060997,25.060997,53.060997,41.060997,29.060997,31.060997,33.060997,34.060997,33.060997,34.060997,38.060997,38.060997,35.060997,25.060997,28.060997,35.060997,47.060997,24.060997,-34.939003,-3.939003,21.060997,9.060997,-13.939003,-34.939003,0.06099701,22.060997,32.060997,21.060997,3.060997,-22.939003,-33.939003,-38.939003,-49.939003,-46.939003,-30.939003,-35.939003,-44.939003,-53.939003,-66.939,-80.939,-82.939,-73.939,-54.939003,-45.939003,-36.939003,-32.939003,-28.939003,-24.939003,-20.939003,-16.939003,-15.939003,-8.939003,-2.939003,1.060997,-0.939003,-2.939003,-6.939003,-8.939003,-7.939003,-12.939003,-16.939003,-15.939003,-18.939003,-24.939003,-33.939003,-40.939003,-45.939003,-54.939003,-62.939003,-68.939,-74.939,-79.939,-80.939,-86.939,-95.939,-96.939,-97.939,-97.939,-58.939003,-3.939003,48.060997,56.060997,20.060997,-49.939003,-100.939,-101.939,-98.939,-94.939,-84.939,-77.939,-73.939,-66.939,-59.939003,-51.939003,-50.939003,-48.939003,-38.939003,-31.939003,-26.939003,-26.939003,-22.939003,-10.939003,-12.939003,-17.939003,-13.939003,-10.939003,-8.939003,-15.939003,-21.939003,-22.939003,-27.939003,-31.939003,-33.939003,-36.939003,-41.939003,-39.939003,-39.939003,-43.939003,-44.939003,-40.939003,-7.939003,4.060997,-2.939003,-14.939003,-17.939003,5.060997,11.060997,10.060997,14.060997,5.060997,-16.939003,-13.939003,-4.939003,12.060997,9.060997,-2.939003,-11.939003,-17.939003,-21.939003,-26.939003,-31.939003,-35.939003,-40.939003,-45.939003,-44.939003,-48.939003,-57.939003,-55.939003,-48.939003,-32.939003,-23.939003,-16.939003,-17.939003,-17.939003,-18.939003,-13.939003,-14.939003,-28.939003,-35.939003
,-38.939003,-29.939003,-22.939003,-17.939003,-16.939003,-16.939003,-19.939003,-19.939003,-17.939003,-18.939003,-21.939003,-27.939003,-29.939003,-34.939003,-46.939003,-49.939003,-49.939003,-48.939003,-47.939003,-42.939003,-46.939003,-47.939003,-38.939003,-30.939003,-26.939003,-31.939003,-30.939003,-21.939003,-30.939003,-35.939003,-26.939003,-16.939003,-8.939003,-7.939003,-8.939003,-10.939003,-9.939003,-7.939003,-2.939003,6.060997,10.060997,-33.939003,-69.939,-94.939,-91.939,-88.939,-93.939,-95.939,-95.939,-99.939,-101.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-90.939,-71.939,-87.939,-100.939,-100.939,-100.939,-100.939,-100.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-98.939,-99.939,-99.939,-99.939,-99.939,-98.939,-99.939,-92.939,-76.939,1.060997,43.060997,48.060997,-9.939003,-63.939003,-90.939,-96.939,-94.939,-89.939,-88.939,-88.939,-86.939,-83.939,-81.939,-81.939,-81.939,-75.939,-78.939,-89.939,-98.939,-99.939,-87.939,-76.939,-69.939,-86.939,-92.939,-88.939,-95.939,-98.939,-93.939,-80.939,-66.939,-69.939,-76.939,-89.939,-77.939,-65.939,-62.939003,-67.939,-73.939,-60.939003,-60.939003,-72.939,-75.939,-71.939,-54.939003,-54.939003,-61.939003,-62.939003,-52.939003,-31.939003,-35.939003,-41.939003,-46.939003,-39.939003,-30.939003,-52.939003,-53.939003,-36.939003,-23.939003,-17.939003,-27.939003,-31.939003,-33.939003,-33.939003,-29.939003,-23.939003,-25.939003,-30.939003,-40.939003,-46.939003,-50.939003,-50.939003,-52.939003,-56.939003,-53.939003,-50.939003,-48.939003,-50.939003,-53.939003,-55.939003,-56.939003,-55.939003,-56.939003,-60.939003,-73.939,-74.939,-71.939,-72.939,-73.939,-74.939,-73.939,-73.939,0.06099701,8.060997,17.060997,15.060997,11.060997,2.060997,-0.939003,-3.939003,-8.939003,-10.939003,-12.939003,-17.939003,-21.939003,-24.939003,-20.939003,-17.939003,-18.939003,-17.939003,-13.939003,-11.939003,-12.939003,-14.939003,-18.939003,-24.939003,-31.939003,-34.939003,-35.939003,-43.939003,-48.939003,-51.939003,-53.939003,-55.939
003,-58.939003,-64.939,-71.939,-73.939,-76.939,-81.939,-82.939,-83.939,-89.939,-88.939,-86.939,-96.939,-82.939,-46.939003,-34.939003,-41.939003,-86.939,-101.939,-102.939,-97.939,-94.939,-94.939,-91.939,-88.939,-86.939,-81.939,-76.939,-75.939,-71.939,-65.939,-61.939003,-57.939003,-55.939003,-52.939003,-47.939003,-39.939003,-34.939003,-34.939003,-25.939003,-18.939003,-17.939003,-15.939003,-12.939003,-7.939003,-6.939003,-6.939003,-11.939003,-13.939003,-10.939003,-14.939003,-19.939003,-7.939003,-4.939003,-8.939003,-2.939003,-0.939003,-12.939003,-29.939003,-45.939003,-37.939003,-36.939003,-43.939003,-48.939003,-51.939003,-52.939003,-47.939003,-39.939003,-37.939003,-36.939003,-37.939003,-35.939003,-31.939003,-27.939003,-24.939003,-20.939003,-12.939003,-5.939003,1.060997,5.060997,10.060997,16.060997,26.060997,34.060997,33.060997,32.060997,31.060997,35.060997,38.060997,39.060997,39.060997,39.060997,41.060997,41.060997,39.060997,33.060997,30.060997,30.060997,26.060997,22.060997,16.060997,13.060997,12.060997,8.060997,4.060997,1.060997,0.06099701,1.060997,-0.939003,-1.939003,-2.939003,-4.939003,-4.939003,-1.939003,1.060997,4.060997,2.060997,0.06099701,-3.939003,-0.939003,1.060997,3.060997,1.060997,-0.939003,4.060997,5.060997,0.06099701,0.06099701,-0.939003,-0.939003,-0.939003,-0.939003,-9.939003,-15.939003,-19.939003,-27.939003,-34.939003,-36.939003,-35.939003,-33.939003,-35.939003,-40.939003,-49.939003,-54.939003,-59.939003,-64.939,-69.939,-75.939,-77.939,-81.939,-86.939,-87.939,-87.939,-86.939,-84.939,-82.939,-75.939,-69.939,-64.939,-59.939003,-55.939003,-52.939003,-47.939003,-42.939003,-39.939003,-35.939003,-28.939003,-25.939003,-23.939003,-18.939003,-15.939003,-11.939003,-4.939003,-1.939003,-1.939003,2.060997,5.060997,0.06099701,-6.939003,-13.939003,-11.939003,-9.939003,-6.939003,-1.939003,2.060997,3.060997,6.060997,10.060997,14.060997,17.060997,19.060997,22.060997,29.060997,43.060997,25.060997,-1.939003,14.060997,23.060997,25.060997,49.060997,69.061,73.061,67.061,59.06099
7,65.061,68.061,69.061,70.061,70.061,71.061,70.061,69.061,64.061,60.060997,57.060997,57.060997,53.060997,38.060997,26.060997,19.060997,42.060997,45.060997,28.060997,37.060997,33.060997,-14.939003,-59.939003,-98.939,-86.939,-82.939,-88.939,-96.939,-99.939,-88.939,-51.939003,-7.939003,-5.939003,-4.939003,-4.939003,-0.939003,0.06099701,-6.939003,-5.939003,-1.939003,-3.939003,-5.939003,-7.939003,-9.939003,-12.939003,-16.939003,-19.939003,-23.939003,-26.939003,-30.939003,-35.939003,-38.939003,-43.939003,-48.939003,-50.939003,-51.939003,-54.939003,-57.939003,-61.939003,-63.939003,-64.939,-68.939,-72.939,-76.939,-80.939,-84.939,-89.939,-93.939,-91.939,-76.939,-67.939,-63.939003,-90.939,-100.939,-94.939,-90.939,-86.939,-80.939,-78.939,-73.939,-47.939003,-11.939003,33.060997,23.060997,9.060997,0.06099701,-20.939003,-43.939003,-38.939003,-34.939003,-32.939003,-29.939003,-24.939003,-16.939003,-13.939003,-11.939003,-9.939003,-7.939003,-5.939003,-4.939003,-4.939003,-7.939003,-7.939003,-4.939003,-6.939003,-10.939003,-14.939003,-16.939003,-17.939003,-15.939003,-21.939003,-28.939003,-28.939003,-30.939003,-34.939003,-28.939003,-28.939003,-46.939003,-50.939003,-50.939003,-58.939003,-61.939003,-60.939003,-65.939,-70.939,-73.939,-77.939,-80.939,-85.939,-88.939,-89.939,-92.939,-93.939,-85.939,-14.939003,72.061,57.060997,43.060997,28.060997,-34.939003,-80.939,-73.939,-67.939,-62.939003,-58.939003,-53.939003,-49.939003,-45.939003,-42.939003,-42.939003,-39.939003,-36.939003,-31.939003,-27.939003,-23.939003,-20.939003,-17.939003,-16.939003,-12.939003,-7.939003,-7.939003,-5.939003,-1.939003,-4.939003,-6.939003,-4.939003,-5.939003,-8.939003,-14.939003,-17.939003,-17.939003,-17.939003,-20.939003,-27.939003,-33.939003,-36.939003,-38.939003,-41.939003,-47.939003,-36.939003,-25.939003,-18.939003,-29.939003,-42.939003,-13.939003,13.060997,37.060997,36.060997,33.060997,34.060997,36.060997,37.060997,36.060997,34.060997,31.060997,36.060997,31.060997,1.060997,-3.939003,3.060997,11.060997,18.060997,23.
060997,14.060997,5.060997,0.06099701,-4.939003,-11.939003,-15.939003,-17.939003,-16.939003,-11.939003,-17.939003,-52.939003,-53.939003,-38.939003,-22.939003,-12.939003,-9.939003,-3.939003,0.06099701,2.060997,5.060997,10.060997,16.060997,19.060997,17.060997,14.060997,10.060997,0.06099701,-9.939003,-16.939003,10.060997,30.060997,42.060997,34.060997,28.060997,32.060997,33.060997,32.060997,25.060997,26.060997,35.060997,35.060997,29.060997,9.060997,2.060997,2.060997,22.060997,19.060997,-6.939003,6.060997,17.060997,10.060997,1.060997,-4.939003,22.060997,21.060997,-5.939003,5.060997,9.060997,-12.939003,-27.939003,-39.939003,-45.939003,-35.939003,-11.939003,-32.939003,-50.939003,-51.939003,-67.939,-84.939,-65.939,-43.939003,-19.939003,-21.939003,-23.939003,-21.939003,-22.939003,-26.939003,-30.939003,-33.939003,-34.939003,-36.939003,-38.939003,-40.939003,-44.939003,-49.939003,-55.939003,-59.939003,-59.939003,-62.939003,-63.939003,-62.939003,-64.939,-68.939,-73.939,-76.939,-78.939,-83.939,-86.939,-87.939,-87.939,-87.939,-82.939,-81.939,-82.939,-78.939,-75.939,-76.939,-55.939003,-25.939003,-3.939003,-3.939003,-22.939003,-42.939003,-55.939003,-50.939003,-47.939003,-45.939003,-38.939003,-34.939003,-33.939003,-31.939003,-30.939003,-29.939003,-28.939003,-27.939003,-29.939003,-32.939003,-35.939003,-39.939003,-41.939003,-39.939003,-44.939003,-51.939003,-50.939003,-51.939003,-54.939003,-60.939003,-66.939,-67.939,-61.939003,-52.939003,-44.939003,-40.939003,-42.939003,-40.939003,-42.939003,-55.939003,-52.939003,-40.939003,-11.939003,-3.939003,-14.939003,-22.939003,-25.939003,-17.939003,-17.939003,-21.939003,-17.939003,-20.939003,-30.939003,-32.939003,-32.939003,-31.939003,-32.939003,-36.939003,-39.939003,-40.939003,-41.939003,-40.939003,-39.939003,-41.939003,-40.939003,-40.939003,-36.939003,-32.939003,-30.939003,-29.939003,-29.939003,-34.939003,-35.939003,-33.939003,-33.939003,-33.939003,-35.939003,-32.939003,-32.939003,-35.939003,-35.939003,-34.939003,-27.939003,-22.939003,-18.939003,
-17.939003,-16.939003,-17.939003,-15.939003,-11.939003,-15.939003,-21.939003,-30.939003,-17.939003,-10.939003,-20.939003,-24.939003,-27.939003,-29.939003,-29.939003,-30.939003,-32.939003,-33.939003,-31.939003,-28.939003,-26.939003,-29.939003,-31.939003,-31.939003,-34.939003,-36.939003,-34.939003,-33.939003,-34.939003,-34.939003,-34.939003,-33.939003,-34.939003,-36.939003,-37.939003,-32.939003,-26.939003,-33.939003,-41.939003,-48.939003,-48.939003,-47.939003,-47.939003,-47.939003,-48.939003,-54.939003,-56.939003,-54.939003,-56.939003,-59.939003,-60.939003,-59.939003,-59.939003,-63.939003,-63.939003,-57.939003,-63.939003,-68.939,-70.939,-72.939,-74.939,-77.939,-78.939,-78.939,-79.939,-80.939,-81.939,-83.939,-84.939,-81.939,-80.939,-79.939,-82.939,-84.939,-87.939,-79.939,-63.939003,-10.939003,20.060997,29.060997,-21.939003,-67.939,-91.939,-96.939,-94.939,-96.939,-97.939,-94.939,-93.939,-93.939,-94.939,-93.939,-91.939,-90.939,-92.939,-97.939,-98.939,-94.939,-84.939,-75.939,-69.939,-84.939,-89.939,-84.939,-92.939,-97.939,-98.939,-92.939,-85.939,-82.939,-82.939,-87.939,-75.939,-62.939003,-53.939003,-56.939003,-62.939003,-45.939003,-45.939003,-62.939003,-71.939,-72.939,-57.939003,-58.939003,-63.939003,-60.939003,-52.939003,-38.939003,-36.939003,-38.939003,-49.939003,-49.939003,-45.939003,-55.939003,-52.939003,-38.939003,-29.939003,-25.939003,-33.939003,-35.939003,-33.939003,-29.939003,-26.939003,-24.939003,-26.939003,-30.939003,-43.939003,-50.939003,-53.939003,-51.939003,-52.939003,-57.939003,-53.939003,-49.939003,-47.939003,-51.939003,-56.939003,-57.939003,-57.939003,-55.939003,-54.939003,-57.939003,-69.939,-67.939,-60.939003,-65.939,-63.939003,-57.939003,-60.939003,-63.939003,-11.939003,-14.939003,-20.939003,-22.939003,-25.939003,-29.939003,-29.939003,-29.939003,-30.939003,-28.939003,-24.939003,-26.939003,-23.939003,-15.939003,-8.939003,-2.939003,0.06099701,-15.939003,-34.939003,-39.939003,-46.939003,-53.939003,-61.939003,-68.939,-75.939,-80.939,-84.939,-91.939,-96.939,-
97.939,-98.939,-99.939,-100.939,-100.939,-100.939,-100.939,-99.939,-98.939,-97.939,-96.939,-95.939,-92.939,-88.939,-89.939,-75.939,-46.939003,-29.939003,-30.939003,-68.939,-81.939,-81.939,-72.939,-67.939,-65.939,-62.939003,-58.939003,-52.939003,-46.939003,-39.939003,-38.939003,-35.939003,-28.939003,-22.939003,-17.939003,-14.939003,-13.939003,-11.939003,-2.939003,1.060997,-2.939003,2.060997,5.060997,6.060997,5.060997,5.060997,1.060997,-2.939003,-4.939003,-13.939003,-20.939003,-23.939003,-26.939003,-26.939003,19.060997,38.060997,34.060997,48.060997,46.060997,2.060997,-39.939003,-75.939,-61.939003,-56.939003,-59.939003,-61.939003,-61.939003,-56.939003,-44.939003,-28.939003,-16.939003,-9.939003,-7.939003,-0.939003,8.060997,21.060997,29.060997,37.060997,46.060997,54.060997,61.060997,65.061,69.061,71.061,76.061,79.061,75.061,69.061,61.060997,55.060997,48.060997,38.060997,33.060997,29.060997,21.060997,12.060997,4.060997,-2.939003,-8.939003,-15.939003,-19.939003,-23.939003,-31.939003,-35.939003,-34.939003,-36.939003,-37.939003,-36.939003,-29.939003,-21.939003,-15.939003,-10.939003,-6.939003,-5.939003,-1.939003,12.060997,19.060997,23.060997,23.060997,26.060997,34.060997,40.060997,47.060997,60.060997,43.060997,17.060997,11.060997,5.060997,-1.939003,-10.939003,-18.939003,-25.939003,-31.939003,-36.939003,-50.939003,-63.939003,-72.939,-81.939,-89.939,-92.939,-94.939,-94.939,-91.939,-89.939,-88.939,-86.939,-84.939,-82.939,-82.939,-81.939,-80.939,-79.939,-78.939,-70.939,-62.939003,-54.939003,-50.939003,-47.939003,-39.939003,-31.939003,-25.939003,-20.939003,-15.939003,-12.939003,-5.939003,0.06099701,3.060997,5.060997,8.060997,9.060997,10.060997,7.060997,6.060997,6.060997,8.060997,5.060997,-1.939003,-1.939003,-2.939003,-11.939003,-12.939003,-12.939003,-13.939003,-5.939003,12.060997,35.060997,51.060997,48.060997,51.060997,56.060997,61.060997,65.061,67.061,67.061,72.061,89.061,53.060997,3.060997,24.060997,36.060997,37.060997,62.060997,82.061,84.061,74.061,60.060997,59.060997,59.060997
,60.060997,58.060997,56.060997,51.060997,48.060997,45.060997,38.060997,32.060997,26.060997,20.060997,15.060997,8.060997,1.060997,-4.939003,4.060997,2.060997,-9.939003,-8.939003,-10.939003,-25.939003,-58.939003,-95.939,-72.939,-62.939003,-65.939,-86.939,-99.939,-86.939,-40.939003,13.060997,10.060997,6.060997,2.060997,0.06099701,-2.939003,-9.939003,-13.939003,-16.939003,-19.939003,-27.939003,-38.939003,-47.939003,-53.939003,-56.939003,-61.939003,-66.939,-74.939,-81.939,-88.939,-95.939,-100.939,-101.939,-101.939,-100.939,-99.939,-98.939,-96.939,-95.939,-94.939,-92.939,-90.939,-89.939,-87.939,-88.939,-91.939,-92.939,-85.939,-58.939003,-37.939003,-26.939003,-77.939,-100.939,-95.939,-78.939,-61.939003,-45.939003,-40.939003,-39.939003,-32.939003,-24.939003,-16.939003,-19.939003,-24.939003,-23.939003,-13.939003,-0.939003,4.060997,6.060997,8.060997,7.060997,8.060997,12.060997,10.060997,6.060997,4.060997,2.060997,1.060997,-0.939003,-4.939003,-13.939003,-18.939003,-22.939003,-29.939003,-37.939003,-46.939003,-53.939003,-58.939003,-62.939003,-69.939,-75.939,-78.939,-81.939,-85.939,-62.939003,-54.939003,-91.939,-94.939,-84.939,-90.939,-94.939,-95.939,-93.939,-92.939,-89.939,-88.939,-85.939,-85.939,-84.939,-84.939,-82.939,-79.939,-72.939,-23.939003,37.060997,18.060997,5.060997,-3.939003,-29.939003,-46.939003,-37.939003,-28.939003,-20.939003,-15.939003,-10.939003,-6.939003,-4.939003,-2.939003,-1.939003,1.060997,4.060997,5.060997,5.060997,4.060997,3.060997,2.060997,0.06099701,-1.939003,-2.939003,-3.939003,-4.939003,-4.939003,-11.939003,-17.939003,-18.939003,-24.939003,-33.939003,-44.939003,-51.939003,-57.939003,-59.939003,-63.939003,-70.939,-78.939,-84.939,-88.939,-90.939,-91.939,-59.939003,-31.939003,-21.939003,-40.939003,-64.939,-20.939003,19.060997,54.060997,42.060997,30.060997,34.060997,32.060997,29.060997,20.060997,14.060997,10.060997,8.060997,1.060997,-17.939003,-21.939003,-17.939003,-17.939003,-19.939003,-25.939003,-31.939003,-36.939003,-37.939003,-43.939003,-50.939003,-48.93
9003,-46.939003,-45.939003,-37.939003,-38.939003,-59.939003,-53.939003,-35.939003,-19.939003,-12.939003,-15.939003,-8.939003,-1.939003,0.06099701,2.060997,5.060997,11.060997,16.060997,20.060997,10.060997,2.060997,-0.939003,-11.939003,-19.939003,14.060997,31.060997,32.060997,30.060997,29.060997,33.060997,33.060997,30.060997,19.060997,18.060997,30.060997,27.060997,18.060997,-1.939003,-14.939003,-20.939003,6.060997,20.060997,22.060997,14.060997,7.060997,5.060997,9.060997,16.060997,30.060997,13.060997,-33.939003,-4.939003,15.060997,-5.939003,-22.939003,-36.939003,-34.939003,-22.939003,-1.939003,-30.939003,-54.939003,-49.939003,-65.939,-84.939,-59.939003,-36.939003,-14.939003,-23.939003,-32.939003,-32.939003,-37.939003,-44.939003,-54.939003,-61.939003,-63.939003,-70.939,-76.939,-81.939,-84.939,-88.939,-92.939,-95.939,-96.939,-94.939,-91.939,-89.939,-89.939,-89.939,-88.939,-87.939,-86.939,-85.939,-84.939,-80.939,-75.939,-71.939,-64.939,-58.939003,-54.939003,-48.939003,-43.939003,-47.939003,-46.939003,-43.939003,-44.939003,-46.939003,-48.939003,-37.939003,-26.939003,-20.939003,-18.939003,-20.939003,-19.939003,-19.939003,-21.939003,-24.939003,-27.939003,-31.939003,-31.939003,-31.939003,-40.939003,-49.939003,-57.939003,-62.939003,-67.939,-72.939,-78.939,-83.939,-83.939,-86.939,-90.939,-92.939,-93.939,-92.939,-80.939,-62.939003,-49.939003,-41.939003,-40.939003,-38.939003,-41.939003,-56.939003,-51.939003,-38.939003,-24.939003,-21.939003,-30.939003,-31.939003,-34.939003,-39.939003,-42.939003,-45.939003,-41.939003,-39.939003,-39.939003,-42.939003,-45.939003,-52.939003,-51.939003,-49.939003,-45.939003,-42.939003,-40.939003,-35.939003,-31.939003,-29.939003,-24.939003,-19.939003,-17.939003,-10.939003,0.06099701,-3.939003,-13.939003,-38.939003,-48.939003,-49.939003,-46.939003,-46.939003,-48.939003,-47.939003,-45.939003,-40.939003,-35.939003,-31.939003,-25.939003,-20.939003,-16.939003,-16.939003,-16.939003,-17.939003,-13.939003,-9.939003,-13.939003,-22.939003,-36.939003,-5.939003,17.
060997,14.060997,8.060997,2.060997,1.060997,-1.939003,-8.939003,-6.939003,-5.939003,-7.939003,-8.939003,-8.939003,-9.939003,-13.939003,-18.939003,-27.939003,-33.939003,-27.939003,-27.939003,-30.939003,-32.939003,-33.939003,-30.939003,-34.939003,-38.939003,-42.939003,-40.939003,-37.939003,-35.939003,-33.939003,-30.939003,-30.939003,-30.939003,-27.939003,-26.939003,-26.939003,-31.939003,-33.939003,-29.939003,-32.939003,-35.939003,-36.939003,-35.939003,-34.939003,-40.939003,-45.939003,-49.939003,-48.939003,-47.939003,-48.939003,-51.939003,-54.939003,-57.939003,-59.939003,-61.939003,-60.939003,-60.939003,-61.939003,-63.939003,-64.939,-60.939003,-59.939003,-59.939003,-61.939003,-64.939,-67.939,-59.939003,-46.939003,-26.939003,-12.939003,-2.939003,-33.939003,-62.939003,-76.939,-79.939,-79.939,-85.939,-87.939,-83.939,-83.939,-85.939,-88.939,-86.939,-84.939,-87.939,-91.939,-94.939,-94.939,-91.939,-82.939,-74.939,-69.939,-83.939,-88.939,-83.939,-90.939,-96.939,-99.939,-97.939,-94.939,-88.939,-82.939,-77.939,-68.939,-58.939003,-45.939003,-44.939003,-47.939003,-32.939003,-35.939003,-56.939003,-70.939,-75.939,-65.939,-63.939003,-62.939003,-56.939003,-51.939003,-47.939003,-43.939003,-43.939003,-53.939003,-57.939003,-57.939003,-55.939003,-49.939003,-38.939003,-34.939003,-32.939003,-37.939003,-36.939003,-33.939003,-29.939003,-27.939003,-29.939003,-30.939003,-34.939003,-47.939003,-56.939003,-59.939003,-53.939003,-53.939003,-56.939003,-52.939003,-48.939003,-47.939003,-52.939003,-57.939003,-59.939003,-58.939003,-54.939003,-54.939003,-56.939003,-66.939,-62.939003,-55.939003,-59.939003,-55.939003,-44.939003,-49.939003,-54.939003,-13.939003,-9.939003,-1.939003,9.060997,14.060997,12.060997,18.060997,27.060997,36.060997,39.060997,40.060997,43.060997,46.060997,50.060997,60.060997,65.061,58.060997,-4.939003,-83.939,-87.939,-90.939,-92.939,-93.939,-95.939,-97.939,-97.939,-96.939,-100.939,-101.939,-102.939,-99.939,-96.939,-95.939,-92.939,-90.939,-90.939,-88.939,-83.939,-81.939,-77.939,-70.939
,-63.939003,-56.939003,-54.939003,-46.939003,-31.939003,-21.939003,-16.939003,-23.939003,-20.939003,-13.939003,-7.939003,-0.939003,6.060997,3.060997,3.060997,12.060997,14.060997,13.060997,9.060997,5.060997,0.06099701,-3.939003,-6.939003,-3.939003,-10.939003,-20.939003,-27.939003,-37.939003,-47.939003,-58.939003,-65.939,-67.939,-70.939,-75.939,-77.939,-77.939,-78.939,-74.939,-72.939,-72.939,-65.939,-45.939003,40.060997,80.061,70.061,74.061,69.061,40.060997,24.060997,16.060997,33.060997,47.060997,57.060997,59.060997,60.060997,61.060997,64.061,67.061,68.061,69.061,67.061,62.060997,56.060997,51.060997,47.060997,42.060997,35.060997,26.060997,16.060997,7.060997,-1.939003,-10.939003,-17.939003,-21.939003,-22.939003,-24.939003,-27.939003,-25.939003,-26.939003,-33.939003,-31.939003,-25.939003,-20.939003,-16.939003,-14.939003,-4.939003,3.060997,6.060997,14.060997,25.060997,26.060997,28.060997,29.060997,34.060997,37.060997,36.060997,36.060997,37.060997,33.060997,30.060997,27.060997,19.060997,10.060997,0.06099701,-11.939003,-22.939003,-23.939003,10.060997,79.061,82.061,86.061,109.061,59.060997,-17.939003,-57.939003,-78.939,-80.939,-82.939,-83.939,-85.939,-86.939,-88.939,-91.939,-91.939,-90.939,-83.939,-75.939,-70.939,-67.939,-65.939,-50.939003,-41.939003,-37.939003,-28.939003,-20.939003,-14.939003,-10.939003,-8.939003,-4.939003,0.06099701,5.060997,6.060997,8.060997,13.060997,15.060997,14.060997,11.060997,8.060997,6.060997,-1.939003,-8.939003,-10.939003,-10.939003,-11.939003,-18.939003,-27.939003,-40.939003,-36.939003,-34.939003,-44.939003,-48.939003,-54.939003,-66.939,-74.939,-79.939,-79.939,-77.939,-73.939,-51.939003,-24.939003,-21.939003,-6.939003,21.060997,61.060997,88.061,80.061,78.061,78.061,76.061,73.061,73.061,72.061,72.061,76.061,44.060997,2.060997,11.060997,15.060997,16.060997,28.060997,38.060997,37.060997,26.060997,11.060997,10.060997,6.060997,1.060997,-2.939003,-7.939003,-11.939003,-13.939003,-13.939003,-19.939003,-22.939003,-21.939003,-24.939003,-25.939003,-20.93900
3,-16.939003,-14.939003,-18.939003,-19.939003,-18.939003,-7.939003,-2.939003,-16.939003,-56.939003,-94.939,-23.939003,13.060997,15.060997,-51.939003,-101.939,-91.939,-82.939,-72.939,-69.939,-72.939,-78.939,-79.939,-79.939,-80.939,-81.939,-83.939,-84.939,-85.939,-88.939,-90.939,-92.939,-92.939,-93.939,-95.939,-96.939,-97.939,-98.939,-100.939,-101.939,-99.939,-95.939,-91.939,-87.939,-81.939,-73.939,-70.939,-65.939,-55.939003,-49.939003,-44.939003,-36.939003,-37.939003,-51.939003,-80.939,-83.939,-24.939003,27.060997,60.060997,-49.939003,-100.939,-94.939,-60.939003,-22.939003,15.060997,20.060997,10.060997,6.060997,-3.939003,-17.939003,-11.939003,-5.939003,-8.939003,-12.939003,-15.939003,-16.939003,-22.939003,-33.939003,-39.939003,-46.939003,-51.939003,-56.939003,-61.939003,-68.939,-74.939,-78.939,-79.939,-80.939,-82.939,-83.939,-84.939,-86.939,-88.939,-90.939,-91.939,-93.939,-94.939,-95.939,-96.939,-97.939,-98.939,-99.939,-72.939,-59.939003,-90.939,-85.939,-69.939,-64.939,-62.939003,-66.939,-62.939003,-55.939003,-44.939003,-36.939003,-30.939003,-25.939003,-22.939003,-19.939003,-12.939003,-7.939003,-11.939003,-18.939003,-27.939003,-32.939003,-31.939003,-22.939003,-5.939003,7.060997,5.060997,7.060997,10.060997,9.060997,5.060997,-0.939003,-5.939003,-11.939003,-15.939003,-18.939003,-22.939003,-31.939003,-40.939003,-49.939003,-54.939003,-59.939003,-64.939,-71.939,-79.939,-80.939,-80.939,-80.939,-82.939,-83.939,-83.939,-85.939,-87.939,-89.939,-91.939,-92.939,-93.939,-93.939,-94.939,-94.939,-95.939,-95.939,-89.939,-78.939,-51.939003,-27.939003,-19.939003,-28.939003,-43.939003,-29.939003,-18.939003,-11.939003,-20.939003,-27.939003,-23.939003,-26.939003,-32.939003,-38.939003,-40.939003,-40.939003,-43.939003,-42.939003,-32.939003,-32.939003,-36.939003,-33.939003,-34.939003,-40.939003,-35.939003,-28.939003,-22.939003,-21.939003,-21.939003,-13.939003,-10.939003,-10.939003,3.060997,0.06099701,-48.939003,-54.939003,-40.939003,-19.939003,-14.939003,-23.939003,-13.939003,-3.939003,2.06
0997,6.060997,8.060997,9.060997,13.060997,19.060997,1.060997,-9.939003,4.060997,-16.939003,-43.939003,3.060997,30.060997,34.060997,31.060997,29.060997,32.060997,31.060997,28.060997,18.060997,17.060997,21.060997,8.060997,-0.939003,1.060997,-6.939003,-14.939003,18.060997,35.060997,38.060997,10.060997,-10.939003,-10.939003,-5.939003,1.060997,8.060997,-1.939003,-28.939003,1.060997,17.060997,-12.939003,-21.939003,-21.939003,-13.939003,-10.939003,-12.939003,-33.939003,-49.939003,-48.939003,-58.939003,-73.939,-82.939,-85.939,-83.939,-85.939,-87.939,-87.939,-88.939,-89.939,-92.939,-93.939,-94.939,-95.939,-95.939,-90.939,-86.939,-83.939,-80.939,-77.939,-73.939,-64.939,-56.939003,-54.939003,-49.939003,-42.939003,-37.939003,-32.939003,-27.939003,-27.939003,-25.939003,-18.939003,-16.939003,-14.939003,-15.939003,-13.939003,-10.939003,-9.939003,-12.939003,-21.939003,-32.939003,-43.939003,-33.939003,-25.939003,-19.939003,-38.939003,-55.939003,-59.939003,-63.939003,-69.939,-74.939,-78.939,-80.939,-83.939,-85.939,-86.939,-86.939,-86.939,-89.939,-91.939,-92.939,-94.939,-94.939,-92.939,-90.939,-88.939,-85.939,-81.939,-78.939,-69.939,-62.939003,-57.939003,-49.939003,-41.939003,-38.939003,-36.939003,-35.939003,-33.939003,-31.939003,-30.939003,-32.939003,-36.939003,-46.939003,-48.939003,-42.939003,-36.939003,-35.939003,-43.939003,-40.939003,-33.939003,-30.939003,-29.939003,-31.939003,-26.939003,-17.939003,-4.939003,-2.939003,-3.939003,4.060997,8.060997,8.060997,7.060997,8.060997,14.060997,20.060997,23.060997,13.060997,12.060997,21.060997,5.060997,-16.939003,-43.939003,-49.939003,-47.939003,-45.939003,-42.939003,-41.939003,-42.939003,-40.939003,-35.939003,-34.939003,-32.939003,-23.939003,-15.939003,-11.939003,-14.939003,-18.939003,-20.939003,-18.939003,-15.939003,-17.939003,-26.939003,-42.939003,0.06099701,36.060997,40.060997,37.060997,31.060997,36.060997,32.060997,18.060997,26.060997,35.060997,41.060997,40.060997,37.060997,35.060997,34.060997,37.060997,-2.939003,-24.939003,7.060997,27.06
0997,41.060997,33.060997,30.060997,30.060997,24.060997,21.060997,24.060997,24.060997,17.060997,-40.939003,-74.939,-85.939,-82.939,-80.939,-81.939,-78.939,-75.939,-76.939,-75.939,-73.939,-72.939,-70.939,-67.939,-67.939,-67.939,-61.939003,-59.939003,-57.939003,-64.939,-67.939,-61.939003,-59.939003,-59.939003,-62.939003,-64.939,-67.939,-59.939003,-52.939003,-49.939003,-48.939003,-47.939003,-47.939003,-49.939003,-52.939003,-47.939003,-43.939003,-40.939003,-37.939003,-35.939003,-41.939003,-43.939003,-39.939003,-35.939003,-33.939003,-35.939003,-38.939003,-40.939003,-40.939003,-40.939003,-39.939003,-37.939003,-36.939003,-39.939003,-41.939003,-43.939003,-45.939003,-52.939003,-66.939,-83.939,-92.939,-84.939,-76.939,-71.939,-88.939,-94.939,-89.939,-93.939,-95.939,-89.939,-79.939,-69.939,-75.939,-71.939,-56.939003,-55.939003,-52.939003,-40.939003,-33.939003,-28.939003,-29.939003,-40.939003,-61.939003,-75.939,-83.939,-79.939,-69.939,-57.939003,-47.939003,-46.939003,-52.939003,-62.939003,-67.939,-59.939003,-56.939003,-54.939003,-45.939003,-38.939003,-33.939003,-32.939003,-31.939003,-28.939003,-30.939003,-34.939003,-37.939003,-38.939003,-37.939003,-38.939003,-42.939003,-52.939003,-61.939003,-69.939,-61.939003,-55.939003,-51.939003,-48.939003,-47.939003,-48.939003,-51.939003,-56.939003,-58.939003,-57.939003,-53.939003,-58.939003,-64.939,-67.939,-67.939,-66.939,-63.939003,-57.939003,-46.939003,-51.939003,-56.939003,-3.939003,6.060997,23.060997,47.060997,57.060997,53.060997,56.060997,62.060997,69.061,70.061,68.061,71.061,73.061,72.061,77.061,81.061,80.061,9.060997,-83.939,-95.939,-99.939,-92.939,-88.939,-85.939,-85.939,-84.939,-81.939,-78.939,-74.939,-69.939,-66.939,-62.939003,-57.939003,-52.939003,-47.939003,-43.939003,-40.939003,-38.939003,-31.939003,-26.939003,-23.939003,-19.939003,-16.939003,-16.939003,-16.939003,-13.939003,-12.939003,-12.939003,-13.939003,-12.939003,-9.939003,-8.939003,-8.939003,-9.939003,-15.939003,-19.939003,-18.939003,-20.939003,-24.939003,-26.939003,-30.939
003,-34.939003,-36.939003,-35.939003,-28.939003,-29.939003,-33.939003,-39.939003,-42.939003,-40.939003,-43.939003,-45.939003,-41.939003,-36.939003,-33.939003,-32.939003,-29.939003,-25.939003,-17.939003,-10.939003,-7.939003,-2.939003,9.060997,56.060997,76.061,69.061,65.061,56.060997,39.060997,36.060997,39.060997,46.060997,50.060997,50.060997,50.060997,48.060997,44.060997,39.060997,34.060997,32.060997,29.060997,26.060997,21.060997,16.060997,12.060997,10.060997,8.060997,2.060997,-2.939003,-6.939003,-10.939003,-12.939003,-14.939003,-14.939003,-13.939003,-12.939003,-11.939003,-10.939003,-7.939003,-6.939003,-6.939003,-2.939003,3.060997,2.060997,3.060997,5.060997,8.060997,8.060997,2.060997,3.060997,8.060997,8.060997,6.060997,3.060997,0.06099701,-3.939003,-10.939003,-13.939003,-14.939003,-18.939003,-21.939003,-24.939003,-30.939003,-37.939003,-46.939003,-56.939003,-65.939,-66.939,-21.939003,69.061,71.061,71.061,86.061,43.060997,-21.939003,-60.939003,-76.939,-67.939,-63.939003,-60.939003,-58.939003,-54.939003,-50.939003,-49.939003,-46.939003,-42.939003,-33.939003,-25.939003,-22.939003,-21.939003,-21.939003,-11.939003,-6.939003,-7.939003,-2.939003,-0.939003,-0.939003,-1.939003,-2.939003,-5.939003,-6.939003,-7.939003,-9.939003,-11.939003,-13.939003,-15.939003,-18.939003,-26.939003,-32.939003,-34.939003,-42.939003,-48.939003,-50.939003,-51.939003,-53.939003,-59.939003,-66.939,-75.939,-73.939,-71.939,-75.939,-75.939,-77.939,-92.939,-98.939,-98.939,-99.939,-96.939,-86.939,-59.939003,-28.939003,-21.939003,-7.939003,12.060997,43.060997,62.060997,51.060997,47.060997,45.060997,41.060997,37.060997,33.060997,28.060997,24.060997,23.060997,9.060997,-7.939003,-3.939003,-2.939003,-1.939003,0.06099701,1.060997,-1.939003,-4.939003,-9.939003,-8.939003,-10.939003,-14.939003,-14.939003,-13.939003,-11.939003,-9.939003,-6.939003,-9.939003,-7.939003,-1.939003,0.06099701,1.060997,4.060997,2.060997,0.06099701,16.060997,21.060997,16.060997,35.060997,44.060997,16.060997,-40.939003,-96.939,-24.939003,17
.060997,29.060997,-43.939003,-99.939,-91.939,-94.939,-100.939,-97.939,-98.939,-102.939,-102.939,-100.939,-98.939,-97.939,-95.939,-92.939,-88.939,-85.939,-81.939,-78.939,-76.939,-72.939,-68.939,-65.939,-61.939003,-59.939003,-55.939003,-52.939003,-50.939003,-45.939003,-38.939003,-33.939003,-27.939003,-20.939003,-19.939003,-17.939003,-9.939003,-8.939003,-11.939003,-6.939003,-14.939003,-37.939003,-75.939,-88.939,-38.939003,7.060997,36.060997,-50.939003,-91.939,-85.939,-63.939003,-39.939003,-12.939003,-15.939003,-29.939003,-18.939003,-0.939003,27.060997,32.060997,31.060997,26.060997,-11.939003,-56.939003,-57.939003,-62.939003,-70.939,-75.939,-80.939,-84.939,-88.939,-91.939,-96.939,-98.939,-98.939,-97.939,-96.939,-93.939,-90.939,-87.939,-85.939,-83.939,-81.939,-75.939,-69.939,-66.939,-65.939,-64.939,-58.939003,-55.939003,-54.939003,-41.939003,-34.939003,-43.939003,-39.939003,-30.939003,-25.939003,-22.939003,-23.939003,-21.939003,-17.939003,-11.939003,-7.939003,-2.939003,-1.939003,-1.939003,-3.939003,-3.939003,-5.939003,-10.939003,-11.939003,-8.939003,-7.939003,-4.939003,1.060997,-14.939003,-28.939003,-31.939003,-33.939003,-33.939003,-35.939003,-39.939003,-44.939003,-48.939003,-53.939003,-56.939003,-59.939003,-62.939003,-69.939,-75.939,-82.939,-86.939,-89.939,-91.939,-95.939,-98.939,-96.939,-92.939,-90.939,-85.939,-81.939,-82.939,-78.939,-73.939,-69.939,-66.939,-65.939,-62.939003,-59.939003,-57.939003,-53.939003,-50.939003,-47.939003,-41.939003,-34.939003,-25.939003,-17.939003,-14.939003,-16.939003,-20.939003,-23.939003,-25.939003,-26.939003,-30.939003,-34.939003,-32.939003,-32.939003,-32.939003,-31.939003,-29.939003,-27.939003,-22.939003,-18.939003,-18.939003,-18.939003,-17.939003,-7.939003,0.06099701,6.060997,8.060997,10.060997,14.060997,17.060997,20.060997,23.060997,23.060997,18.060997,32.060997,25.060997,-37.939003,-53.939003,-47.939003,-17.939003,-9.939003,-23.939003,-26.939003,-21.939003,-3.939003,5.060997,9.060997,11.060997,10.060997,4.060997,-5.939003,-10.939003,-1
.939003,-13.939003,-28.939003,16.060997,38.060997,37.060997,35.060997,33.060997,36.060997,37.060997,34.060997,15.060997,9.060997,12.060997,4.060997,-3.939003,-7.939003,-7.939003,-2.939003,26.060997,35.060997,26.060997,9.060997,2.060997,21.060997,19.060997,7.060997,-3.939003,-12.939003,-19.939003,-2.939003,6.060997,-10.939003,-17.939003,-20.939003,-14.939003,-13.939003,-18.939003,-35.939003,-48.939003,-46.939003,-56.939003,-73.939,-85.939,-89.939,-83.939,-82.939,-80.939,-76.939,-73.939,-70.939,-66.939,-62.939003,-60.939003,-60.939003,-59.939003,-52.939003,-48.939003,-44.939003,-44.939003,-41.939003,-36.939003,-32.939003,-30.939003,-31.939003,-28.939003,-24.939003,-24.939003,-23.939003,-20.939003,-25.939003,-29.939003,-29.939003,-29.939003,-31.939003,-35.939003,-38.939003,-36.939003,-41.939003,-46.939003,-52.939003,-36.939003,-11.939003,8.060997,11.060997,-2.939003,-46.939003,-81.939,-85.939,-87.939,-88.939,-89.939,-90.939,-91.939,-88.939,-85.939,-79.939,-78.939,-77.939,-73.939,-70.939,-68.939,-67.939,-64.939,-56.939003,-53.939003,-52.939003,-47.939003,-44.939003,-42.939003,-38.939003,-35.939003,-34.939003,-33.939003,-32.939003,-36.939003,-36.939003,-34.939003,-34.939003,-35.939003,-35.939003,-34.939003,-31.939003,-28.939003,-27.939003,-29.939003,-26.939003,-23.939003,-21.939003,-18.939003,-14.939003,-10.939003,-14.939003,-26.939003,-20.939003,-8.939003,7.060997,7.060997,0.06099701,2.060997,1.060997,-1.939003,-6.939003,-8.939003,-5.939003,-3.939003,-2.939003,-12.939003,-15.939003,-10.939003,-18.939003,-27.939003,-36.939003,-36.939003,-33.939003,-31.939003,-28.939003,-25.939003,-28.939003,-31.939003,-36.939003,-36.939003,-33.939003,-23.939003,-16.939003,-14.939003,-14.939003,-16.939003,-18.939003,-17.939003,-15.939003,-17.939003,-22.939003,-30.939003,-12.939003,3.060997,3.060997,1.060997,-0.939003,0.06099701,-2.939003,-8.939003,-2.939003,3.060997,10.060997,10.060997,9.060997,9.060997,11.060997,16.060997,-13.939003,-30.939003,-5.939003,9.060997,17.060997,16.060997,15.06
0997,15.060997,12.060997,11.060997,15.060997,16.060997,12.060997,-45.939003,-81.939,-94.939,-91.939,-90.939,-93.939,-93.939,-92.939,-92.939,-92.939,-92.939,-92.939,-91.939,-88.939,-88.939,-87.939,-75.939,-70.939,-70.939,-80.939,-86.939,-82.939,-80.939,-79.939,-81.939,-82.939,-84.939,-78.939,-73.939,-71.939,-70.939,-69.939,-69.939,-71.939,-73.939,-69.939,-66.939,-63.939003,-51.939003,-35.939003,-11.939003,-2.939003,-5.939003,-31.939003,-52.939003,-56.939003,-56.939003,-56.939003,-55.939003,-54.939003,-51.939003,-48.939003,-47.939003,-50.939003,-54.939003,-56.939003,-54.939003,-59.939003,-71.939,-87.939,-95.939,-87.939,-80.939,-75.939,-84.939,-87.939,-85.939,-92.939,-95.939,-82.939,-70.939,-61.939003,-74.939,-70.939,-50.939003,-47.939003,-44.939003,-38.939003,-34.939003,-33.939003,-33.939003,-42.939003,-61.939003,-69.939,-71.939,-65.939,-58.939003,-50.939003,-44.939003,-45.939003,-50.939003,-61.939003,-66.939,-52.939003,-50.939003,-52.939003,-43.939003,-38.939003,-37.939003,-35.939003,-32.939003,-30.939003,-35.939003,-42.939003,-45.939003,-46.939003,-45.939003,-47.939003,-49.939003,-53.939003,-61.939003,-68.939,-59.939003,-50.939003,-40.939003,-37.939003,-37.939003,-41.939003,-47.939003,-54.939003,-54.939003,-51.939003,-47.939003,-55.939003,-64.939,-71.939,-77.939,-82.939,-77.939,-66.939,-47.939003,-47.939003,-48.939003,10.060997,25.060997,47.060997,82.061,94.061,86.061,83.061,81.061,83.061,79.061,76.061,76.061,75.061,69.061,67.061,70.061,78.061,20.060997,-61.939003,-82.939,-86.939,-72.939,-64.939,-57.939003,-57.939003,-55.939003,-51.939003,-43.939003,-34.939003,-24.939003,-22.939003,-18.939003,-10.939003,-5.939003,1.060997,8.060997,10.060997,8.060997,19.060997,26.060997,21.060997,18.060997,15.060997,10.060997,3.060997,-3.939003,-8.939003,-13.939003,-19.939003,-26.939003,-33.939003,-36.939003,-43.939003,-55.939003,-60.939003,-67.939,-75.939,-80.939,-84.939,-82.939,-80.939,-80.939,-74.939,-67.939,-54.939003,-45.939003,-37.939003,-36.939003,-26.939003,-7.939003,3.060997
,12.060997,22.060997,36.060997,49.060997,52.060997,58.060997,65.061,74.061,82.061,86.061,85.061,80.061,61.060997,49.060997,48.060997,36.060997,26.060997,19.060997,23.060997,29.060997,23.060997,11.060997,-3.939003,-6.939003,-11.939003,-20.939003,-29.939003,-39.939003,-41.939003,-44.939003,-47.939003,-48.939003,-46.939003,-41.939003,-37.939003,-33.939003,-29.939003,-23.939003,-14.939003,-7.939003,0.06099701,10.060997,21.060997,31.060997,32.060997,35.060997,39.060997,39.060997,40.060997,45.060997,50.060997,52.060997,39.060997,31.060997,28.060997,17.060997,3.060997,-15.939003,-27.939003,-35.939003,-42.939003,-49.939003,-57.939003,-71.939,-84.939,-97.939,-100.939,-99.939,-99.939,-98.939,-97.939,-97.939,-96.939,-96.939,-95.939,-94.939,-94.939,-51.939003,33.060997,36.060997,32.060997,31.060997,13.060997,-12.939003,-34.939003,-37.939003,-18.939003,-10.939003,-4.939003,-1.939003,5.060997,14.060997,16.060997,18.060997,19.060997,25.060997,27.060997,22.060997,19.060997,17.060997,15.060997,11.060997,5.060997,1.060997,-5.939003,-15.939003,-21.939003,-27.939003,-40.939003,-50.939003,-59.939003,-62.939003,-66.939,-77.939,-81.939,-85.939,-95.939,-100.939,-99.939,-101.939,-101.939,-101.939,-100.939,-101.939,-100.939,-100.939,-99.939,-98.939,-96.939,-91.939,-84.939,-79.939,-91.939,-92.939,-85.939,-86.939,-83.939,-71.939,-50.939003,-27.939003,-16.939003,-8.939003,-2.939003,8.060997,12.060997,1.060997,-3.939003,-6.939003,-8.939003,-13.939003,-18.939003,-26.939003,-33.939003,-35.939003,-29.939003,-20.939003,-17.939003,-15.939003,-15.939003,-19.939003,-23.939003,-26.939003,-20.939003,-13.939003,-8.939003,-6.939003,-5.939003,0.06099701,7.060997,16.060997,23.060997,28.060997,31.060997,36.060997,46.060997,51.060997,53.060997,50.060997,35.060997,23.060997,68.061,82.061,64.061,89.061,96.061,49.060997,-25.939003,-98.939,-45.939003,-8.939003,13.060997,-48.939003,-96.939,-88.939,-92.939,-102.939,-99.939,-97.939,-96.939,-93.939,-90.939,-87.939,-83.939,-78.939,-70.939,-63.939003,-56.939003,-48.9390
03,-41.939003,-37.939003,-30.939003,-22.939003,-14.939003,-8.939003,-4.939003,2.060997,9.060997,7.060997,12.060997,20.060997,25.060997,29.060997,31.060997,28.060997,26.060997,28.060997,19.060997,7.060997,4.060997,-10.939003,-38.939003,-75.939,-94.939,-68.939,-43.939003,-26.939003,-63.939003,-80.939,-73.939,-74.939,-75.939,-74.939,-86.939,-99.939,-64.939,-4.939003,83.061,81.061,69.061,60.060997,-11.939003,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-96.939,-96.939,-96.939,-96.939,-93.939,-86.939,-83.939,-80.939,-74.939,-67.939,-61.939003,-56.939003,-52.939003,-47.939003,-35.939003,-24.939003,-17.939003,-15.939003,-13.939003,-2.939003,4.060997,6.060997,-0.939003,-2.939003,10.060997,11.060997,9.060997,9.060997,11.060997,13.060997,11.060997,9.060997,6.060997,4.060997,3.060997,-1.939003,-5.939003,-12.939003,-24.939003,-34.939003,-38.939003,-4.939003,43.060997,47.060997,46.060997,41.060997,-35.939003,-94.939,-96.939,-98.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-98.939,-97.939,-96.939,-96.939,-96.939,-94.939,-91.939,-87.939,-79.939,-73.939,-68.939,-57.939003,-48.939003,-49.939003,-42.939003,-31.939003,-23.939003,-18.939003,-15.939003,-10.939003,-4.939003,-2.939003,2.060997,8.060997,13.060997,14.060997,11.060997,1.060997,-8.939003,-8.939003,-7.939003,-3.939003,-13.939003,-16.939003,-14.939003,-14.939003,-15.939003,-15.939003,-11.939003,-4.939003,2.060997,7.060997,12.060997,26.060997,29.060997,5.060997,3.060997,12.060997,30.060997,49.060997,69.061,61.060997,53.060997,49.060997,51.060997,57.060997,52.060997,46.060997,38.060997,48.060997,37.060997,-29.939003,-52.939003,-53.939003,-15.939003,-3.939003,-20.939003,-40.939003,-44.939003,-12.939003,2.060997,9.060997,15.060997,6.060997,-15.939003,-11.939003,-6.939003,-9.939003,-6.939003,0.06099701,34.060997,45.060997,36.060997,36.060997,37.060997,41.060997,44.060997,43.060997,14.060997,1.060997,5.060997,5.060997,0.06099701,-15.939003,-6.939003,12.060997,30.060997,28.060997,5.060997,11.06
0997,28.060997,65.061,53.060997,18.060997,-10.939003,-21.939003,-11.939003,-10.939003,-9.939003,-5.939003,-13.939003,-25.939003,-24.939003,-22.939003,-21.939003,-37.939003,-48.939003,-45.939003,-58.939003,-76.939,-79.939,-71.939,-54.939003,-51.939003,-48.939003,-40.939003,-34.939003,-28.939003,-19.939003,-12.939003,-7.939003,-8.939003,-7.939003,-4.939003,-2.939003,-2.939003,-8.939003,-9.939003,-5.939003,-9.939003,-15.939003,-20.939003,-22.939003,-25.939003,-32.939003,-36.939003,-38.939003,-50.939003,-61.939003,-68.939,-72.939,-75.939,-82.939,-86.939,-86.939,-94.939,-100.939,-99.939,-46.939003,24.060997,51.060997,44.060997,4.060997,-54.939003,-97.939,-96.939,-93.939,-88.939,-83.939,-79.939,-77.939,-68.939,-58.939003,-47.939003,-43.939003,-42.939003,-34.939003,-28.939003,-24.939003,-22.939003,-17.939003,-6.939003,-5.939003,-8.939003,-3.939003,-3.939003,-8.939003,-14.939003,-19.939003,-23.939003,-27.939003,-32.939003,-38.939003,-39.939003,-36.939003,-38.939003,-43.939003,-53.939003,-45.939003,-27.939003,2.060997,7.060997,-9.939003,-13.939003,-8.939003,5.060997,6.060997,1.060997,5.060997,-3.939003,-24.939003,-20.939003,-12.939003,1.060997,-2.939003,-14.939003,-20.939003,-28.939003,-34.939003,-41.939003,-46.939003,-47.939003,-48.939003,-50.939003,-54.939003,-56.939003,-57.939003,-48.939003,-38.939003,-28.939003,-21.939003,-18.939003,-18.939003,-15.939003,-10.939003,-15.939003,-23.939003,-39.939003,-40.939003,-35.939003,-24.939003,-19.939003,-21.939003,-16.939003,-13.939003,-14.939003,-14.939003,-13.939003,-16.939003,-16.939003,-14.939003,-28.939003,-41.939003,-50.939003,-49.939003,-48.939003,-55.939003,-55.939003,-50.939003,-50.939003,-50.939003,-45.939003,-43.939003,-42.939003,-39.939003,-36.939003,-31.939003,-37.939003,-40.939003,-35.939003,-36.939003,-38.939003,-31.939003,-28.939003,-28.939003,-26.939003,-24.939003,-24.939003,-19.939003,-17.939003,-49.939003,-70.939,-80.939,-79.939,-79.939,-85.939,-87.939,-89.939,-90.939,-92.939,-94.939,-97.939,-98.939,-98.939,-98.939
,-96.939,-81.939,-75.939,-80.939,-91.939,-98.939,-98.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-99.939,-100.939,-100.939,-100.939,-100.939,-100.939,-101.939,-78.939,-41.939003,29.060997,58.060997,48.060997,-25.939003,-85.939,-97.939,-96.939,-91.939,-92.939,-89.939,-83.939,-81.939,-80.939,-84.939,-87.939,-89.939,-83.939,-84.939,-91.939,-97.939,-99.939,-91.939,-85.939,-80.939,-76.939,-76.939,-77.939,-90.939,-96.939,-79.939,-68.939,-64.939,-78.939,-73.939,-48.939003,-43.939003,-39.939003,-38.939003,-40.939003,-43.939003,-39.939003,-43.939003,-57.939003,-56.939003,-51.939003,-42.939003,-42.939003,-44.939003,-44.939003,-45.939003,-45.939003,-53.939003,-56.939003,-42.939003,-44.939003,-50.939003,-45.939003,-43.939003,-44.939003,-39.939003,-34.939003,-36.939003,-44.939003,-54.939003,-53.939003,-53.939003,-54.939003,-55.939003,-56.939003,-55.939003,-59.939003,-62.939003,-54.939003,-43.939003,-28.939003,-25.939003,-27.939003,-34.939003,-43.939003,-52.939003,-48.939003,-44.939003,-41.939003,-50.939003,-61.939003,-75.939,-87.939,-97.939,-92.939,-77.939,-49.939003,-43.939003,-41.939003,17.060997,24.060997,36.060997,69.061,78.061,66.061,62.060997,56.060997,48.060997,40.060997,34.060997,29.060997,22.060997,13.060997,8.060997,5.060997,2.060997,-10.939003,-26.939003,-23.939003,-19.939003,-14.939003,-12.939003,-9.939003,-6.939003,-4.939003,-2.939003,-2.939003,-4.939003,-6.939003,-9.939003,-12.939003,-15.939003,-22.939003,-28.939003,-30.939003,-34.939003,-40.939003,-42.939003,-46.939003,-52.939003,-56.939003,-59.939003,-62.939003,-62.939003,-59.939003,-35.939003,-24.939003,-42.939003,-50.939003,-50.939003,-39.939003,-28.939003,-21.939003,-15.939003,-10.939003,-5.939003,0.06099701,7.060997,18.060997,27.060997,34.060997,40.060997,46.060997,51.060997,54.060997,57.060997,52.060997,50.060997,51.060997,48.060997,45.060997,41.060997,36.060997,30.060997,28.060997,24.060997,18.060997,16.060997,12.060997,3.060997,2.060997,1.060997,-3.939003,-9.939003,-14.93900
3,-20.939003,-24.939003,-25.939003,-18.939003,-10.939003,-9.939003,-9.939003,-8.939003,-1.939003,1.060997,-1.939003,5.060997,15.060997,19.060997,21.060997,22.060997,25.060997,26.060997,24.060997,20.060997,17.060997,22.060997,20.060997,13.060997,7.060997,4.060997,3.060997,-3.939003,-13.939003,-21.939003,-27.939003,-31.939003,-43.939003,-52.939003,-53.939003,-52.939003,-51.939003,-55.939003,-58.939003,-59.939003,-63.939003,-67.939,-74.939,-78.939,-80.939,-83.939,-82.939,-79.939,-78.939,-76.939,-76.939,-71.939,-63.939003,-57.939003,-49.939003,-40.939003,-36.939003,-31.939003,-25.939003,-20.939003,-15.939003,-9.939003,-6.939003,-6.939003,-10.939003,-14.939003,-14.939003,-11.939003,-6.939003,0.06099701,5.060997,7.060997,4.060997,-1.939003,-8.939003,-13.939003,-16.939003,-20.939003,-26.939003,-35.939003,-39.939003,-44.939003,-48.939003,-52.939003,-57.939003,-61.939003,-65.939,-67.939,-68.939,-70.939,-74.939,-76.939,-78.939,-82.939,-85.939,-88.939,-89.939,-91.939,-94.939,-95.939,-95.939,-98.939,-99.939,-100.939,-98.939,-94.939,-91.939,-85.939,-78.939,-75.939,-70.939,-64.939,-61.939003,-57.939003,-49.939003,-39.939003,-30.939003,-30.939003,-26.939003,-20.939003,-16.939003,-11.939003,-6.939003,-5.939003,-6.939003,-6.939003,-7.939003,-8.939003,-15.939003,-19.939003,-15.939003,-15.939003,-17.939003,-12.939003,-9.939003,-7.939003,-6.939003,-3.939003,5.060997,1.060997,-5.939003,-2.939003,0.06099701,3.060997,19.060997,34.060997,46.060997,45.060997,41.060997,46.060997,49.060997,51.060997,57.060997,60.060997,57.060997,61.060997,65.061,67.061,68.061,71.061,73.061,73.061,69.061,47.060997,26.060997,62.060997,74.061,61.060997,64.061,47.060997,-17.939003,-64.939,-93.939,-48.939003,-20.939003,-8.939003,-53.939003,-87.939,-80.939,-83.939,-87.939,-56.939003,-34.939003,-23.939003,-20.939003,-16.939003,-10.939003,-8.939003,-8.939003,-3.939003,-2.939003,-1.939003,2.060997,4.060997,2.060997,0.06099701,-0.939003,0.06099701,-1.939003,-9.939003,-12.939003,-15.939003,-18.939003,-23.939003,-29.9390
03,-31.939003,-34.939003,-37.939003,-42.939003,-46.939003,-50.939003,-57.939003,-66.939,-67.939,-70.939,-75.939,-82.939,-81.939,-57.939003,-32.939003,-13.939003,-58.939003,-78.939,-73.939,-75.939,-80.939,-90.939,-97.939,-100.939,-68.939,-13.939003,64.061,50.060997,30.060997,23.060997,-21.939003,-75.939,-68.939,-60.939003,-54.939003,-49.939003,-43.939003,-31.939003,-27.939003,-25.939003,-25.939003,-22.939003,-16.939003,-12.939003,-9.939003,-9.939003,-7.939003,-4.939003,-1.939003,-0.939003,-2.939003,-1.939003,-1.939003,-4.939003,-4.939003,-5.939003,-8.939003,-13.939003,-20.939003,-18.939003,-20.939003,-30.939003,-34.939003,-34.939003,-41.939003,-49.939003,-59.939003,-63.939003,-65.939,-66.939,-67.939,-67.939,-69.939,-70.939,-73.939,-77.939,-80.939,-80.939,-9.939003,82.061,80.061,71.061,54.060997,-33.939003,-98.939,-93.939,-88.939,-83.939,-79.939,-75.939,-71.939,-67.939,-62.939003,-56.939003,-53.939003,-50.939003,-45.939003,-38.939003,-31.939003,-28.939003,-26.939003,-22.939003,-20.939003,-18.939003,-12.939003,-10.939003,-11.939003,-6.939003,-2.939003,-2.939003,-1.939003,-2.939003,-4.939003,-5.939003,-4.939003,-5.939003,-7.939003,-15.939003,-20.939003,-26.939003,-27.939003,-32.939003,-37.939003,-30.939003,-23.939003,-23.939003,-33.939003,-43.939003,-23.939003,3.060997,36.060997,36.060997,34.060997,40.060997,41.060997,40.060997,39.060997,38.060997,37.060997,41.060997,34.060997,3.060997,-1.939003,3.060997,11.060997,19.060997,25.060997,18.060997,9.060997,0.06099701,-3.939003,-5.939003,-7.939003,-10.939003,-16.939003,-8.939003,-11.939003,-40.939003,-51.939003,-50.939003,-24.939003,-17.939003,-29.939003,-45.939003,-46.939003,-17.939003,1.060997,14.060997,13.060997,-0.939003,-26.939003,-21.939003,-11.939003,-0.939003,1.060997,-0.939003,9.060997,8.060997,-4.939003,3.060997,16.060997,36.060997,40.060997,37.060997,33.060997,26.060997,14.060997,10.060997,11.060997,25.060997,31.060997,35.060997,31.060997,26.060997,20.060997,28.060997,34.060997,28.060997,11.060997,-10.939003,-23.9
39003,-27.939003,-19.939003,-19.939003,-16.939003,-6.939003,-13.939003,-26.939003,-20.939003,-20.939003,-27.939003,-44.939003,-56.939003,-51.939003,-65.939,-81.939,-55.939003,-33.939003,-17.939003,-18.939003,-21.939003,-17.939003,-20.939003,-25.939003,-28.939003,-29.939003,-31.939003,-35.939003,-42.939003,-51.939003,-56.939003,-59.939003,-65.939,-68.939,-70.939,-72.939,-74.939,-75.939,-76.939,-77.939,-79.939,-81.939,-81.939,-85.939,-89.939,-91.939,-89.939,-88.939,-87.939,-83.939,-76.939,-73.939,-68.939,-61.939003,-39.939003,-13.939003,-8.939003,-14.939003,-29.939003,-37.939003,-40.939003,-34.939003,-31.939003,-30.939003,-29.939003,-28.939003,-26.939003,-25.939003,-24.939003,-24.939003,-24.939003,-26.939003,-27.939003,-29.939003,-31.939003,-37.939003,-42.939003,-41.939003,-48.939003,-56.939003,-56.939003,-61.939003,-69.939,-72.939,-74.939,-76.939,-68.939,-55.939003,-49.939003,-46.939003,-47.939003,-43.939003,-45.939003,-60.939003,-51.939003,-32.939003,-10.939003,-7.939003,-23.939003,-29.939003,-29.939003,-24.939003,-27.939003,-35.939003,-38.939003,-41.939003,-44.939003,-43.939003,-43.939003,-41.939003,-45.939003,-49.939003,-45.939003,-45.939003,-47.939003,-46.939003,-43.939003,-36.939003,-33.939003,-31.939003,-25.939003,-19.939003,-13.939003,-20.939003,-30.939003,-43.939003,-46.939003,-43.939003,-45.939003,-46.939003,-45.939003,-45.939003,-45.939003,-44.939003,-41.939003,-35.939003,-28.939003,-22.939003,-16.939003,-17.939003,-18.939003,-19.939003,-17.939003,-12.939003,-16.939003,-19.939003,-23.939003,-0.939003,14.060997,2.060997,-2.939003,-5.939003,-10.939003,-14.939003,-17.939003,-17.939003,-18.939003,-18.939003,-18.939003,-17.939003,-22.939003,-25.939003,-25.939003,-31.939003,-35.939003,-32.939003,-32.939003,-33.939003,-34.939003,-37.939003,-39.939003,-39.939003,-39.939003,-39.939003,-39.939003,-38.939003,-37.939003,-36.939003,-36.939003,-37.939003,-38.939003,-41.939003,-39.939003,-37.939003,-39.939003,-41.939003,-39.939003,-43.939003,-46.939003,-45.939003,-46.9390
03,-47.939003,-47.939003,-46.939003,-47.939003,-48.939003,-49.939003,-52.939003,-54.939003,-56.939003,-58.939003,-59.939003,-60.939003,-59.939003,-60.939003,-60.939003,-62.939003,-63.939003,-69.939,-72.939,-73.939,-73.939,-75.939,-81.939,-67.939,-44.939003,-4.939003,12.060997,7.060997,-39.939003,-78.939,-82.939,-83.939,-84.939,-88.939,-89.939,-86.939,-87.939,-88.939,-91.939,-92.939,-93.939,-94.939,-97.939,-99.939,-100.939,-98.939,-91.939,-83.939,-75.939,-71.939,-71.939,-75.939,-90.939,-101.939,-95.939,-91.939,-89.939,-85.939,-69.939,-41.939003,-49.939003,-57.939003,-57.939003,-45.939003,-28.939003,-28.939003,-30.939003,-34.939003,-32.939003,-34.939003,-45.939003,-50.939003,-51.939003,-51.939003,-50.939003,-48.939003,-55.939003,-59.939003,-55.939003,-53.939003,-50.939003,-44.939003,-43.939003,-48.939003,-49.939003,-49.939003,-51.939003,-57.939003,-66.939,-69.939,-72.939,-73.939,-68.939,-64.939,-66.939,-64.939,-60.939003,-50.939003,-41.939003,-33.939003,-36.939003,-42.939003,-48.939003,-46.939003,-43.939003,-45.939003,-46.939003,-46.939003,-50.939003,-56.939003,-66.939,-72.939,-76.939,-77.939,-70.939,-52.939003,-60.939003,-66.939,9.060997,12.060997,16.060997,34.060997,36.060997,24.060997,20.060997,15.060997,8.060997,2.060997,-1.939003,-6.939003,-13.939003,-18.939003,-20.939003,-22.939003,-25.939003,-22.939003,-15.939003,-7.939003,-5.939003,-7.939003,-6.939003,-6.939003,-11.939003,-12.939003,-12.939003,-15.939003,-19.939003,-23.939003,-26.939003,-31.939003,-35.939003,-41.939003,-46.939003,-46.939003,-47.939003,-49.939003,-50.939003,-51.939003,-51.939003,-50.939003,-48.939003,-45.939003,-40.939003,-33.939003,-16.939003,-10.939003,-27.939003,-23.939003,-9.939003,0.06099701,11.060997,19.060997,26.060997,31.060997,37.060997,41.060997,45.060997,51.060997,55.060997,59.060997,61.060997,62.060997,60.060997,58.060997,58.060997,51.060997,45.060997,38.060997,32.060997,26.060997,19.060997,12.060997,4.060997,3.060997,-0.939003,-6.939003,-8.939003,-12.939003,-18.939003,-17.939003,-1
5.939003,-15.939003,-19.939003,-26.939003,-24.939003,-21.939003,-22.939003,-16.939003,-9.939003,-2.939003,2.060997,7.060997,8.060997,8.060997,5.060997,10.060997,19.060997,21.060997,21.060997,18.060997,20.060997,18.060997,9.060997,2.060997,-2.939003,1.060997,-3.939003,-13.939003,-21.939003,-27.939003,-31.939003,-43.939003,-57.939003,-65.939,-71.939,-75.939,-85.939,-92.939,-91.939,-89.939,-85.939,-82.939,-79.939,-76.939,-73.939,-71.939,-69.939,-67.939,-65.939,-63.939003,-59.939003,-52.939003,-46.939003,-39.939003,-35.939003,-30.939003,-23.939003,-17.939003,-10.939003,-3.939003,-1.939003,2.060997,6.060997,8.060997,9.060997,10.060997,3.060997,-11.939003,-9.939003,-6.939003,-0.939003,3.060997,5.060997,-5.939003,-13.939003,-17.939003,-23.939003,-31.939003,-40.939003,-47.939003,-54.939003,-59.939003,-66.939,-75.939,-82.939,-87.939,-89.939,-93.939,-97.939,-101.939,-103.939,-102.939,-100.939,-98.939,-97.939,-96.939,-95.939,-91.939,-89.939,-87.939,-85.939,-82.939,-79.939,-76.939,-72.939,-68.939,-66.939,-65.939,-62.939003,-57.939003,-50.939003,-43.939003,-35.939003,-32.939003,-27.939003,-21.939003,-18.939003,-15.939003,-10.939003,-5.939003,-1.939003,2.060997,4.060997,6.060997,7.060997,7.060997,6.060997,1.060997,-3.939003,-6.939003,-7.939003,-6.939003,-4.939003,-2.939003,5.060997,6.060997,5.060997,11.060997,17.060997,24.060997,28.060997,34.060997,48.060997,35.060997,13.060997,11.060997,13.060997,19.060997,42.060997,66.061,85.061,79.061,65.061,68.061,70.061,71.061,73.061,73.061,64.061,63.060997,63.060997,63.060997,59.060997,55.060997,53.060997,51.060997,47.060997,30.060997,13.060997,31.060997,37.060997,31.060997,22.060997,-1.939003,-61.939003,-86.939,-91.939,-37.939003,-13.939003,-19.939003,-56.939003,-84.939,-82.939,-86.939,-86.939,-43.939003,-13.939003,3.060997,2.060997,1.060997,5.060997,2.060997,-1.939003,-2.939003,-4.939003,-9.939003,-9.939003,-11.939003,-16.939003,-20.939003,-25.939003,-26.939003,-31.939003,-40.939003,-46.939003,-51.939003,-53.939003,-60.939003,-69.939,-73.
939,-78.939,-82.939,-85.939,-88.939,-93.939,-98.939,-102.939,-98.939,-94.939,-91.939,-89.939,-81.939,-58.939003,-34.939003,-19.939003,-62.939003,-84.939,-82.939,-84.939,-86.939,-88.939,-78.939,-62.939003,-45.939003,-16.939003,24.060997,9.060997,-6.939003,-8.939003,-20.939003,-34.939003,-27.939003,-21.939003,-16.939003,-13.939003,-9.939003,0.06099701,5.060997,5.060997,3.060997,2.060997,1.060997,1.060997,0.06099701,-3.939003,-7.939003,-9.939003,-7.939003,-9.939003,-15.939003,-20.939003,-25.939003,-30.939003,-31.939003,-33.939003,-40.939003,-47.939003,-55.939003,-45.939003,-43.939003,-65.939,-71.939,-69.939,-73.939,-83.939,-97.939,-100.939,-100.939,-98.939,-96.939,-94.939,-90.939,-87.939,-85.939,-83.939,-81.939,-78.939,-17.939003,59.060997,54.060997,44.060997,28.060997,-25.939003,-64.939,-57.939003,-50.939003,-42.939003,-38.939003,-34.939003,-31.939003,-27.939003,-23.939003,-19.939003,-15.939003,-12.939003,-11.939003,-6.939003,-0.939003,1.060997,1.060997,1.060997,-0.939003,-1.939003,-1.939003,-5.939003,-11.939003,-13.939003,-14.939003,-13.939003,-16.939003,-21.939003,-27.939003,-31.939003,-31.939003,-34.939003,-38.939003,-47.939003,-54.939003,-62.939003,-65.939,-70.939,-75.939,-51.939003,-30.939003,-29.939003,-44.939003,-61.939003,-31.939003,2.060997,41.060997,39.060997,34.060997,38.060997,36.060997,32.060997,29.060997,26.060997,22.060997,21.060997,14.060997,-8.939003,-12.939003,-9.939003,-9.939003,-11.939003,-13.939003,-17.939003,-22.939003,-30.939003,-35.939003,-38.939003,-35.939003,-34.939003,-35.939003,-27.939003,-25.939003,-39.939003,-47.939003,-50.939003,-32.939003,-26.939003,-30.939003,-43.939003,-46.939003,-25.939003,-3.939003,15.060997,8.060997,-5.939003,-27.939003,-15.939003,-1.939003,3.060997,-3.939003,-13.939003,3.060997,0.06099701,-24.939003,-14.939003,3.060997,33.060997,38.060997,33.060997,42.060997,40.060997,25.060997,17.060997,19.060997,44.060997,47.060997,40.060997,32.060997,27.060997,25.060997,34.060997,34.060997,9.060997,-3.939003,-12.939003,-16.9390
03,-21.939003,-25.939003,-20.939003,-13.939003,-8.939003,-12.939003,-20.939003,-25.939003,-30.939003,-37.939003,-50.939003,-59.939003,-55.939003,-68.939,-82.939,-51.939003,-34.939003,-30.939003,-34.939003,-37.939003,-35.939003,-41.939003,-49.939003,-54.939003,-58.939003,-62.939003,-66.939,-73.939,-84.939,-88.939,-90.939,-92.939,-94.939,-94.939,-92.939,-90.939,-89.939,-86.939,-84.939,-82.939,-80.939,-77.939,-75.939,-73.939,-71.939,-67.939,-64.939,-61.939003,-56.939003,-49.939003,-42.939003,-35.939003,-30.939003,-29.939003,-32.939003,-34.939003,-36.939003,-38.939003,-29.939003,-21.939003,-19.939003,-19.939003,-22.939003,-26.939003,-29.939003,-30.939003,-32.939003,-36.939003,-40.939003,-42.939003,-45.939003,-48.939003,-52.939003,-56.939003,-63.939003,-69.939,-73.939,-78.939,-83.939,-84.939,-87.939,-91.939,-89.939,-87.939,-85.939,-74.939,-59.939003,-47.939003,-44.939003,-47.939003,-42.939003,-41.939003,-50.939003,-46.939003,-36.939003,-28.939003,-27.939003,-34.939003,-38.939003,-40.939003,-39.939003,-41.939003,-45.939003,-49.939003,-49.939003,-46.939003,-45.939003,-43.939003,-38.939003,-39.939003,-41.939003,-34.939003,-31.939003,-32.939003,-31.939003,-26.939003,-15.939003,-12.939003,-10.939003,-8.939003,-3.939003,5.060997,-11.939003,-32.939003,-53.939003,-57.939003,-52.939003,-53.939003,-53.939003,-51.939003,-50.939003,-49.939003,-46.939003,-43.939003,-37.939003,-31.939003,-24.939003,-15.939003,-17.939003,-19.939003,-21.939003,-18.939003,-13.939003,-18.939003,-23.939003,-29.939003,12.060997,41.060997,27.060997,20.060997,17.060997,15.060997,11.060997,2.060997,5.060997,8.060997,6.060997,7.060997,9.060997,3.060997,-2.939003,-8.939003,-22.939003,-28.939003,-13.939003,-7.939003,-7.939003,-12.939003,-16.939003,-16.939003,-17.939003,-18.939003,-18.939003,-21.939003,-25.939003,-40.939003,-45.939003,-42.939003,-43.939003,-45.939003,-45.939003,-42.939003,-39.939003,-42.939003,-42.939003,-40.939003,-42.939003,-44.939003,-42.939003,-43.939003,-46.939003,-43.939003,-42.939003,-43.93
9003,-42.939003,-42.939003,-43.939003,-45.939003,-47.939003,-47.939003,-48.939003,-48.939003,-47.939003,-47.939003,-47.939003,-47.939003,-47.939003,-53.939003,-55.939003,-57.939003,-56.939003,-57.939003,-66.939,-58.939003,-43.939003,-28.939003,-22.939003,-23.939003,-43.939003,-60.939003,-62.939003,-64.939,-67.939,-70.939,-71.939,-69.939,-70.939,-72.939,-76.939,-77.939,-78.939,-82.939,-85.939,-89.939,-94.939,-95.939,-90.939,-81.939,-73.939,-71.939,-71.939,-76.939,-91.939,-100.939,-92.939,-89.939,-88.939,-80.939,-64.939,-40.939003,-54.939003,-65.939,-60.939003,-45.939003,-27.939003,-27.939003,-28.939003,-29.939003,-27.939003,-31.939003,-48.939003,-55.939003,-57.939003,-59.939003,-58.939003,-54.939003,-59.939003,-62.939003,-60.939003,-55.939003,-47.939003,-45.939003,-44.939003,-47.939003,-50.939003,-53.939003,-57.939003,-57.939003,-57.939003,-64.939,-73.939,-83.939,-79.939,-74.939,-74.939,-68.939,-58.939003,-50.939003,-45.939003,-43.939003,-50.939003,-56.939003,-53.939003,-50.939003,-47.939003,-51.939003,-51.939003,-48.939003,-51.939003,-56.939003,-59.939003,-60.939003,-60.939003,-65.939,-64.939,-58.939003,-70.939,-78.939,-12.939003,-12.939003,-13.939003,-21.939003,-30.939003,-38.939003,-41.939003,-41.939003,-37.939003,-34.939003,-32.939003,-33.939003,-32.939003,-26.939003,-20.939003,-13.939003,-6.939003,-13.939003,-27.939003,-34.939003,-42.939003,-50.939003,-46.939003,-49.939003,-70.939,-79.939,-82.939,-82.939,-80.939,-74.939,-74.939,-74.939,-70.939,-62.939003,-51.939003,-38.939003,-27.939003,-19.939003,-5.939003,10.060997,24.060997,36.060997,46.060997,60.060997,69.061,74.061,48.060997,27.060997,26.060997,53.060997,89.061,83.061,76.061,67.061,63.060997,58.060997,54.060997,44.060997,30.060997,16.060997,4.060997,-5.939003,-11.939003,-18.939003,-27.939003,-32.939003,-34.939003,-39.939003,-43.939003,-45.939003,-44.939003,-44.939003,-43.939003,-36.939003,-28.939003,-23.939003,-17.939003,-10.939003,-1.939003,8.060997,18.060997,25.060997,29.060997,24.060997,19.060997,14.0609
97,24.060997,33.060997,28.060997,28.060997,32.060997,42.060997,46.060997,42.060997,25.060997,10.060997,0.06099701,-13.939003,-27.939003,-34.939003,-44.939003,-58.939003,-63.939003,-71.939,-85.939,-92.939,-94.939,-93.939,-94.939,-95.939,-96.939,-96.939,-96.939,-98.939,-99.939,-100.939,-98.939,-94.939,-88.939,-79.939,-70.939,-60.939003,-50.939003,-41.939003,-32.939003,-21.939003,-14.939003,-6.939003,-1.939003,4.060997,10.060997,15.060997,19.060997,24.060997,24.060997,25.060997,23.060997,21.060997,19.060997,20.060997,18.060997,13.060997,7.060997,3.060997,-0.939003,-8.939003,-19.939003,-34.939003,-21.939003,18.060997,40.060997,58.060997,72.061,56.060997,22.060997,-53.939003,-92.939,-92.939,-94.939,-95.939,-97.939,-98.939,-99.939,-99.939,-100.939,-100.939,-101.939,-102.939,-102.939,-102.939,-102.939,-103.939,-102.939,-99.939,-93.939,-87.939,-85.939,-83.939,-80.939,-68.939,-60.939003,-56.939003,-50.939003,-41.939003,-30.939003,-23.939003,-17.939003,-7.939003,-1.939003,3.060997,6.060997,11.060997,20.060997,25.060997,29.060997,28.060997,28.060997,29.060997,28.060997,27.060997,25.060997,17.060997,7.060997,5.060997,1.060997,-2.939003,-14.939003,-26.939003,-34.939003,-27.939003,-16.939003,-16.939003,-9.939003,4.060997,40.060997,64.061,64.061,63.060997,62.060997,62.060997,67.061,78.061,77.061,80.061,92.061,70.061,36.060997,26.060997,24.060997,31.060997,51.060997,72.061,90.061,80.061,60.060997,59.060997,56.060997,52.060997,49.060997,45.060997,35.060997,28.060997,22.060997,18.060997,10.060997,0.06099701,-7.939003,-13.939003,-14.939003,-14.939003,-15.939003,-23.939003,-27.939003,-27.939003,-35.939003,-50.939003,-80.939,-92.939,-90.939,-10.939003,13.060997,-18.939003,-57.939003,-87.939,-95.939,-100.939,-98.939,-60.939003,-32.939003,-13.939003,-24.939003,-35.939003,-39.939003,-48.939003,-58.939003,-65.939,-71.939,-79.939,-84.939,-89.939,-94.939,-95.939,-96.939,-96.939,-96.939,-97.939,-98.939,-98.939,-98.939,-99.939,-100.939,-100.939,-101.939,-101.939,-100.939,-101.939,-100.939,-102.
939,-101.939,-87.939,-83.939,-88.939,-96.939,-94.939,-69.939,-51.939003,-42.939003,-77.939,-97.939,-101.939,-101.939,-93.939,-68.939,-27.939003,12.060997,3.060997,-13.939003,-38.939003,-41.939003,-40.939003,-35.939003,-8.939003,24.060997,22.060997,19.060997,15.060997,10.060997,5.060997,1.060997,0.06099701,-1.939003,-9.939003,-19.939003,-33.939003,-41.939003,-48.939003,-56.939003,-66.939,-76.939,-76.939,-80.939,-87.939,-92.939,-96.939,-96.939,-96.939,-97.939,-97.939,-98.939,-99.939,-80.939,-72.939,-94.939,-99.939,-96.939,-88.939,-90.939,-102.939,-100.939,-95.939,-87.939,-81.939,-75.939,-65.939,-56.939003,-50.939003,-43.939003,-37.939003,-33.939003,-27.939003,-24.939003,-30.939003,-34.939003,-35.939003,-12.939003,6.060997,11.060997,16.060997,21.060997,21.060997,20.060997,20.060997,17.060997,15.060997,13.060997,13.060997,14.060997,5.060997,-0.939003,-4.939003,-7.939003,-13.939003,-23.939003,-29.939003,-35.939003,-46.939003,-58.939003,-67.939,-77.939,-83.939,-82.939,-85.939,-90.939,-94.939,-96.939,-96.939,-97.939,-97.939,-98.939,-99.939,-99.939,-100.939,-100.939,-100.939,-61.939003,-29.939003,-28.939003,-41.939003,-58.939003,-38.939003,-18.939003,2.060997,-5.939003,-14.939003,-22.939003,-26.939003,-28.939003,-27.939003,-28.939003,-32.939003,-33.939003,-32.939003,-30.939003,-27.939003,-25.939003,-34.939003,-42.939003,-48.939003,-45.939003,-43.939003,-44.939003,-42.939003,-39.939003,-30.939003,-23.939003,-18.939003,-6.939003,-4.939003,-25.939003,-41.939003,-51.939003,-39.939003,-30.939003,-21.939003,-35.939003,-44.939003,-37.939003,-12.939003,14.060997,0.06099701,-10.939003,-18.939003,7.060997,22.060997,2.060997,-20.939003,-38.939003,15.060997,21.060997,-21.939003,-17.939003,-0.939003,33.060997,37.060997,30.060997,40.060997,43.060997,39.060997,28.060997,24.060997,43.060997,40.060997,29.060997,33.060997,31.060997,22.060997,29.060997,29.060997,9.060997,8.060997,13.060997,10.060997,-3.939003,-30.939003,-13.939003,-0.939003,-10.939003,-9.939003,-7.939003,-38.939003,-52.939003
,-51.939003,-55.939003,-57.939003,-58.939003,-67.939,-77.939,-67.939,-73.939,-95.939,-97.939,-97.939,-95.939,-96.939,-98.939,-99.939,-99.939,-99.939,-100.939,-100.939,-101.939,-98.939,-95.939,-90.939,-85.939,-77.939,-69.939,-64.939,-60.939003,-53.939003,-46.939003,-41.939003,-34.939003,-24.939003,-19.939003,-14.939003,-8.939003,-5.939003,-4.939003,-3.939003,-4.939003,-6.939003,-2.939003,-0.939003,-5.939003,-18.939003,-32.939003,-26.939003,-21.939003,-21.939003,-31.939003,-42.939003,-51.939003,-57.939003,-64.939,-75.939,-83.939,-87.939,-90.939,-93.939,-96.939,-97.939,-98.939,-98.939,-98.939,-99.939,-98.939,-99.939,-100.939,-95.939,-88.939,-85.939,-80.939,-74.939,-66.939,-58.939003,-50.939003,-46.939003,-42.939003,-34.939003,-32.939003,-35.939003,-35.939003,-32.939003,-25.939003,-30.939003,-39.939003,-50.939003,-51.939003,-42.939003,-40.939003,-39.939003,-38.939003,-35.939003,-30.939003,-27.939003,-27.939003,-31.939003,-26.939003,-13.939003,11.060997,13.060997,7.060997,14.060997,15.060997,11.060997,5.060997,4.060997,13.060997,14.060997,12.060997,-1.939003,-6.939003,-0.939003,-22.939003,-43.939003,-58.939003,-55.939003,-45.939003,-42.939003,-37.939003,-28.939003,-30.939003,-34.939003,-45.939003,-46.939003,-42.939003,-33.939003,-25.939003,-18.939003,-16.939003,-17.939003,-19.939003,-18.939003,-16.939003,-23.939003,-28.939003,-33.939003,9.060997,38.060997,24.060997,19.060997,19.060997,24.060997,21.060997,8.060997,19.060997,29.060997,31.060997,34.060997,38.060997,36.060997,30.060997,19.060997,-10.939003,-20.939003,23.060997,39.060997,41.060997,34.060997,33.060997,38.060997,37.060997,37.060997,38.060997,33.060997,19.060997,-58.939003,-97.939,-97.939,-98.939,-98.939,-98.939,-97.939,-96.939,-97.939,-97.939,-95.939,-94.939,-91.939,-87.939,-89.939,-91.939,-68.939,-61.939003,-69.939,-73.939,-75.939,-70.939,-72.939,-74.939,-68.939,-65.939,-64.939,-62.939003,-61.939003,-60.939003,-55.939003,-52.939003,-50.939003,-50.939003,-52.939003,-49.939003,-48.939003,-56.939003,-50.939003,-3
9.939003,-43.939003,-45.939003,-44.939003,-36.939003,-32.939003,-37.939003,-39.939003,-40.939003,-39.939003,-36.939003,-32.939003,-31.939003,-33.939003,-40.939003,-42.939003,-44.939003,-46.939003,-50.939003,-60.939003,-79.939,-92.939,-87.939,-80.939,-73.939,-74.939,-76.939,-81.939,-93.939,-95.939,-70.939,-62.939003,-61.939003,-63.939003,-58.939003,-45.939003,-57.939003,-63.939003,-47.939003,-41.939003,-39.939003,-38.939003,-39.939003,-42.939003,-42.939003,-44.939003,-51.939003,-57.939003,-62.939003,-66.939,-67.939,-62.939003,-64.939,-64.939,-55.939003,-49.939003,-42.939003,-48.939003,-47.939003,-40.939003,-41.939003,-45.939003,-56.939003,-44.939003,-27.939003,-37.939003,-56.939003,-85.939,-88.939,-87.939,-80.939,-69.939,-56.939003,-55.939003,-55.939003,-59.939003,-66.939,-67.939,-50.939003,-54.939003,-64.939,-66.939,-59.939003,-47.939003,-54.939003,-60.939003,-55.939003,-52.939003,-49.939003,-54.939003,-60.939003,-67.939,-73.939,-77.939,-7.939003,-7.939003,-8.939003,-0.939003,3.060997,5.060997,7.060997,9.060997,13.060997,17.060997,20.060997,24.060997,28.060997,32.060997,33.060997,38.060997,50.060997,23.060997,-16.939003,-32.939003,-35.939003,-25.939003,-11.939003,-2.939003,-5.939003,-6.939003,-6.939003,0.06099701,7.060997,15.060997,22.060997,27.060997,26.060997,28.060997,31.060997,37.060997,40.060997,40.060997,40.060997,40.060997,42.060997,39.060997,34.060997,36.060997,33.060997,29.060997,14.060997,2.060997,0.06099701,4.060997,10.060997,6.060997,1.060997,-5.939003,-4.939003,-4.939003,-4.939003,-7.939003,-9.939003,-8.939003,-8.939003,-9.939003,-9.939003,-9.939003,-6.939003,-2.939003,1.060997,1.060997,2.060997,7.060997,10.060997,12.060997,11.060997,14.060997,17.060997,11.060997,8.060997,10.060997,7.060997,4.060997,3.060997,-3.939003,-8.939003,21.060997,46.060997,64.061,62.060997,62.060997,66.061,19.060997,-42.939003,-37.939003,-36.939003,-38.939003,-46.939003,-53.939003,-58.939003,-64.939,-72.939,-75.939,-76.939,-78.939,-75.939,-73.939,-75.939,-73.939,-70.939,-65.939,
-60.939003,-55.939003,-48.939003,-43.939003,-38.939003,-34.939003,-32.939003,-32.939003,-27.939003,-17.939003,-15.939003,-13.939003,-7.939003,-5.939003,-4.939003,0.06099701,2.060997,3.060997,0.06099701,-1.939003,-3.939003,-5.939003,-6.939003,-10.939003,-14.939003,-16.939003,-23.939003,-28.939003,-32.939003,-37.939003,-42.939003,-45.939003,-48.939003,-51.939003,-54.939003,-56.939003,-58.939003,-62.939003,-68.939,-75.939,-43.939003,29.060997,69.061,95.061,99.061,84.061,52.060997,-50.939003,-101.939,-98.939,-97.939,-95.939,-95.939,-91.939,-86.939,-83.939,-79.939,-75.939,-73.939,-68.939,-58.939003,-51.939003,-46.939003,-43.939003,-40.939003,-34.939003,-31.939003,-27.939003,-22.939003,-16.939003,-11.939003,-7.939003,-4.939003,-0.939003,-0.939003,0.06099701,5.060997,4.060997,0.06099701,-0.939003,-1.939003,-2.939003,-5.939003,-8.939003,-12.939003,-16.939003,-19.939003,-22.939003,-27.939003,-34.939003,-36.939003,-40.939003,-44.939003,-49.939003,-54.939003,-55.939003,-57.939003,-59.939003,-65.939,-71.939,-75.939,-55.939003,-27.939003,-23.939003,-11.939003,10.060997,51.060997,79.061,71.061,67.061,64.061,60.060997,59.060997,61.060997,56.060997,52.060997,56.060997,38.060997,13.060997,3.060997,3.060997,11.060997,16.060997,21.060997,27.060997,19.060997,7.060997,5.060997,2.060997,-1.939003,-5.939003,-8.939003,-7.939003,-6.939003,-6.939003,-9.939003,-11.939003,-12.939003,-16.939003,-17.939003,-9.939003,-8.939003,-8.939003,-3.939003,0.06099701,2.060997,-3.939003,-22.939003,-70.939,-89.939,-92.939,-54.939003,-40.939003,-50.939003,-70.939,-88.939,-98.939,-100.939,-97.939,-83.939,-73.939,-64.939,-70.939,-75.939,-77.939,-81.939,-86.939,-89.939,-92.939,-96.939,-98.939,-99.939,-100.939,-98.939,-95.939,-92.939,-88.939,-83.939,-80.939,-78.939,-74.939,-72.939,-70.939,-61.939003,-55.939003,-52.939003,-45.939003,-40.939003,-40.939003,-36.939003,-31.939003,-35.939003,-53.939003,-85.939,-96.939,-98.939,-83.939,-73.939,-66.939,-83.939,-95.939,-102.939,-103.939,-94.939,-65.939,-30.939003,1.060997,
-12.939003,-14.939003,-2.939003,1.060997,3.060997,4.060997,-15.939003,-41.939003,-45.939003,-47.939003,-50.939003,-52.939003,-55.939003,-57.939003,-58.939003,-59.939003,-62.939003,-67.939,-74.939,-78.939,-81.939,-85.939,-89.939,-94.939,-94.939,-94.939,-92.939,-91.939,-89.939,-84.939,-81.939,-78.939,-75.939,-71.939,-66.939,-54.939003,-47.939003,-53.939003,-51.939003,-47.939003,-42.939003,-39.939003,-37.939003,-34.939003,-29.939003,-24.939003,-19.939003,-14.939003,-10.939003,-7.939003,-6.939003,-6.939003,-7.939003,-5.939003,-11.939003,-19.939003,-20.939003,-18.939003,-12.939003,-12.939003,-14.939003,-19.939003,-22.939003,-25.939003,-28.939003,-32.939003,-38.939003,-42.939003,-45.939003,-49.939003,-50.939003,-50.939003,-54.939003,-57.939003,-60.939003,-61.939003,-64.939,-69.939,-72.939,-75.939,-80.939,-85.939,-90.939,-93.939,-94.939,-91.939,-88.939,-86.939,-82.939,-78.939,-73.939,-69.939,-66.939,-62.939003,-57.939003,-51.939003,-46.939003,-44.939003,-44.939003,-30.939003,-19.939003,-17.939003,-20.939003,-24.939003,-22.939003,-21.939003,-22.939003,-26.939003,-30.939003,-31.939003,-28.939003,-24.939003,-24.939003,-23.939003,-22.939003,-19.939003,-17.939003,-15.939003,-14.939003,-12.939003,-7.939003,-2.939003,3.060997,3.060997,2.060997,5.060997,7.060997,9.060997,15.060997,18.060997,18.060997,34.060997,33.060997,-10.939003,-37.939003,-53.939003,-39.939003,-27.939003,-14.939003,-19.939003,-24.939003,-28.939003,-7.939003,17.060997,-1.939003,-12.939003,-15.939003,0.06099701,8.060997,-7.939003,-3.939003,8.060997,27.060997,29.060997,13.060997,16.060997,23.060997,36.060997,36.060997,31.060997,35.060997,37.060997,38.060997,34.060997,31.060997,38.060997,37.060997,34.060997,31.060997,28.060997,26.060997,26.060997,25.060997,20.060997,23.060997,26.060997,11.060997,-7.939003,-29.939003,-24.939003,-19.939003,-19.939003,-23.939003,-29.939003,-36.939003,-33.939003,-22.939003,-32.939003,-44.939003,-53.939003,-65.939,-76.939,-68.939,-71.939,-85.939,-81.939,-77.939,-73.939,-70.939,-66.939,-
61.939003,-56.939003,-51.939003,-50.939003,-48.939003,-46.939003,-43.939003,-40.939003,-38.939003,-35.939003,-31.939003,-28.939003,-27.939003,-27.939003,-25.939003,-23.939003,-25.939003,-24.939003,-24.939003,-29.939003,-33.939003,-34.939003,-36.939003,-38.939003,-41.939003,-47.939003,-53.939003,-53.939003,-54.939003,-59.939003,-32.939003,6.060997,18.060997,13.060997,-10.939003,-48.939003,-77.939,-80.939,-81.939,-81.939,-82.939,-83.939,-82.939,-80.939,-76.939,-71.939,-68.939,-66.939,-63.939003,-59.939003,-54.939003,-51.939003,-49.939003,-47.939003,-45.939003,-42.939003,-38.939003,-36.939003,-38.939003,-35.939003,-32.939003,-31.939003,-35.939003,-39.939003,-38.939003,-37.939003,-36.939003,-35.939003,-35.939003,-40.939003,-36.939003,-29.939003,-17.939003,-17.939003,-29.939003,-24.939003,-18.939003,-13.939003,-11.939003,-11.939003,-10.939003,-18.939003,-34.939003,-28.939003,-16.939003,-2.939003,-3.939003,-9.939003,-10.939003,-14.939003,-20.939003,-26.939003,-29.939003,-25.939003,-29.939003,-34.939003,-37.939003,-36.939003,-31.939003,-35.939003,-37.939003,-37.939003,-34.939003,-30.939003,-31.939003,-28.939003,-20.939003,-23.939003,-30.939003,-43.939003,-45.939003,-42.939003,-34.939003,-26.939003,-20.939003,-17.939003,-16.939003,-17.939003,-18.939003,-19.939003,-27.939003,-30.939003,-28.939003,-18.939003,-12.939003,-21.939003,-27.939003,-29.939003,-25.939003,-25.939003,-30.939003,-28.939003,-25.939003,-22.939003,-19.939003,-18.939003,-16.939003,-16.939003,-18.939003,-29.939003,-32.939003,-16.939003,-10.939003,-9.939003,-9.939003,-8.939003,-5.939003,-6.939003,-6.939003,-2.939003,-0.939003,-6.939003,-55.939003,-79.939,-79.939,-77.939,-76.939,-79.939,-80.939,-82.939,-84.939,-86.939,-89.939,-91.939,-93.939,-91.939,-95.939,-98.939,-72.939,-65.939,-80.939,-87.939,-92.939,-90.939,-91.939,-92.939,-89.939,-87.939,-87.939,-86.939,-85.939,-85.939,-82.939,-81.939,-80.939,-80.939,-81.939,-79.939,-78.939,-82.939,-57.939003,-21.939003,7.060997,14.060997,0.06099701,-39.939003,-71.939,-73
.939,-74.939,-74.939,-74.939,-72.939,-70.939,-70.939,-71.939,-74.939,-75.939,-75.939,-73.939,-75.939,-82.939,-91.939,-96.939,-89.939,-78.939,-67.939,-72.939,-77.939,-84.939,-95.939,-97.939,-76.939,-65.939,-60.939003,-62.939003,-60.939003,-53.939003,-55.939003,-56.939003,-53.939003,-51.939003,-50.939003,-47.939003,-48.939003,-51.939003,-50.939003,-51.939003,-54.939003,-58.939003,-61.939003,-62.939003,-61.939003,-59.939003,-59.939003,-58.939003,-53.939003,-47.939003,-42.939003,-43.939003,-43.939003,-41.939003,-39.939003,-38.939003,-44.939003,-46.939003,-47.939003,-54.939003,-66.939,-84.939,-84.939,-81.939,-70.939,-59.939003,-49.939003,-48.939003,-45.939003,-42.939003,-60.939003,-72.939,-68.939,-70.939,-76.939,-82.939,-76.939,-57.939003,-47.939003,-46.939003,-60.939003,-64.939,-63.939003,-57.939003,-57.939003,-63.939003,-63.939003,-62.939003,8.060997,4.060997,2.060997,30.060997,47.060997,52.060997,57.060997,59.060997,62.060997,64.061,68.061,75.061,78.061,77.061,75.061,77.061,87.061,53.060997,5.060997,-10.939003,-8.939003,10.060997,25.060997,38.060997,48.060997,51.060997,52.060997,61.060997,67.061,73.061,80.061,83.061,78.061,74.061,70.061,68.061,64.061,58.060997,48.060997,38.060997,31.060997,20.060997,8.060997,2.060997,-5.939003,-14.939003,-15.939003,-17.939003,-19.939003,-30.939003,-41.939003,-40.939003,-41.939003,-43.939003,-38.939003,-34.939003,-32.939003,-28.939003,-22.939003,-12.939003,-5.939003,-1.939003,2.060997,7.060997,14.060997,20.060997,24.060997,24.060997,26.060997,33.060997,34.060997,33.060997,28.060997,27.060997,25.060997,13.060997,5.060997,2.060997,-7.939003,-18.939003,-28.939003,-43.939003,-52.939003,3.060997,51.060997,91.061,77.061,67.061,77.061,4.060997,-93.939,-92.939,-90.939,-87.939,-86.939,-84.939,-82.939,-81.939,-80.939,-79.939,-74.939,-65.939,-57.939003,-49.939003,-41.939003,-35.939003,-30.939003,-24.939003,-18.939003,-11.939003,-2.939003,4.060997,8.060997,11.060997,14.060997,12.060997,16.060997,26.060997,21.060997,17.060997,17.060997,12.060997,5.
060997,5.060997,1.060997,-4.939003,-13.939003,-21.939003,-27.939003,-34.939003,-40.939003,-49.939003,-57.939003,-64.939,-73.939,-80.939,-84.939,-90.939,-96.939,-100.939,-102.939,-101.939,-100.939,-100.939,-100.939,-99.939,-98.939,-98.939,-57.939003,24.060997,71.061,99.061,92.061,80.061,57.060997,-37.939003,-82.939,-76.939,-73.939,-70.939,-67.939,-61.939003,-53.939003,-48.939003,-43.939003,-37.939003,-33.939003,-27.939003,-14.939003,-5.939003,-0.939003,2.060997,6.060997,9.060997,8.060997,9.060997,15.060997,19.060997,22.060997,17.060997,16.060997,17.060997,10.060997,4.060997,4.060997,-1.939003,-10.939003,-18.939003,-25.939003,-29.939003,-36.939003,-43.939003,-55.939003,-64.939,-70.939,-74.939,-81.939,-90.939,-93.939,-97.939,-101.939,-102.939,-101.939,-100.939,-99.939,-97.939,-96.939,-95.939,-94.939,-67.939,-33.939003,-26.939003,-12.939003,11.060997,44.060997,64.061,52.060997,46.060997,42.060997,36.060997,32.060997,28.060997,21.060997,16.060997,14.060997,3.060997,-9.939003,-14.939003,-13.939003,-5.939003,-10.939003,-14.939003,-16.939003,-19.939003,-23.939003,-25.939003,-27.939003,-28.939003,-30.939003,-29.939003,-23.939003,-15.939003,-8.939003,-9.939003,-7.939003,-1.939003,-1.939003,1.060997,13.060997,9.060997,3.060997,27.060997,40.060997,40.060997,39.060997,16.060997,-56.939003,-85.939,-95.939,-87.939,-80.939,-72.939,-78.939,-86.939,-98.939,-100.939,-96.939,-99.939,-99.939,-98.939,-95.939,-93.939,-92.939,-91.939,-89.939,-87.939,-86.939,-85.939,-82.939,-81.939,-78.939,-73.939,-68.939,-63.939003,-57.939003,-48.939003,-45.939003,-41.939003,-36.939003,-34.939003,-31.939003,-18.939003,-10.939003,-7.939003,0.06099701,4.060997,1.060997,7.060997,14.060997,-3.939003,-35.939003,-83.939,-95.939,-99.939,-92.939,-86.939,-81.939,-86.939,-92.939,-100.939,-101.939,-95.939,-72.939,-49.939003,-29.939003,-39.939003,-16.939003,41.060997,46.060997,45.060997,42.060997,-20.939003,-98.939,-100.939,-101.939,-100.939,-99.939,-98.939,-97.939,-96.939,-94.939,-93.939,-91.939,-90.939,-88.939,-86.9
39,-84.939,-83.939,-82.939,-81.939,-77.939,-69.939,-64.939,-58.939003,-51.939003,-46.939003,-41.939003,-38.939003,-32.939003,-24.939003,-21.939003,-18.939003,-13.939003,-9.939003,-6.939003,-5.939003,-2.939003,5.060997,7.060997,9.060997,9.060997,10.060997,12.060997,8.060997,5.060997,2.060997,-3.939003,-8.939003,-9.939003,-0.939003,10.060997,13.060997,17.060997,23.060997,-15.939003,-49.939003,-59.939003,-66.939,-73.939,-77.939,-82.939,-90.939,-93.939,-95.939,-97.939,-98.939,-97.939,-95.939,-94.939,-93.939,-92.939,-90.939,-88.939,-86.939,-85.939,-85.939,-84.939,-83.939,-80.939,-77.939,-71.939,-65.939,-58.939003,-50.939003,-42.939003,-35.939003,-30.939003,-26.939003,-20.939003,-14.939003,-6.939003,-0.939003,0.06099701,-1.939003,-8.939003,-12.939003,-9.939003,-7.939003,-5.939003,-10.939003,-17.939003,-24.939003,-24.939003,-23.939003,-17.939003,-11.939003,-4.939003,-4.939003,-1.939003,3.060997,8.060997,8.060997,0.06099701,-0.939003,2.060997,18.060997,33.060997,45.060997,40.060997,35.060997,37.060997,37.060997,36.060997,38.060997,36.060997,32.060997,47.060997,44.060997,-7.939003,-36.939003,-53.939003,-40.939003,-25.939003,-9.939003,-6.939003,-7.939003,-17.939003,-1.939003,19.060997,0.06099701,-10.939003,-12.939003,-4.939003,-1.939003,-12.939003,11.060997,45.060997,35.060997,33.060997,40.060997,38.060997,36.060997,35.060997,33.060997,31.060997,31.060997,33.060997,36.060997,36.060997,35.060997,33.060997,35.060997,37.060997,31.060997,29.060997,31.060997,27.060997,24.060997,27.060997,29.060997,29.060997,12.060997,-2.939003,-15.939003,-27.939003,-36.939003,-34.939003,-42.939003,-52.939003,-31.939003,-15.939003,-2.939003,-21.939003,-41.939003,-52.939003,-65.939,-77.939,-61.939003,-55.939003,-58.939003,-51.939003,-44.939003,-41.939003,-37.939003,-31.939003,-26.939003,-21.939003,-16.939003,-14.939003,-11.939003,-10.939003,-10.939003,-10.939003,-11.939003,-13.939003,-13.939003,-16.939003,-19.939003,-22.939003,-24.939003,-26.939003,-31.939003,-36.939003,-42.939003,-53.939003,-62.939
003,-67.939,-70.939,-74.939,-79.939,-84.939,-90.939,-91.939,-92.939,-93.939,-43.939003,23.060997,35.060997,24.060997,-10.939003,-54.939003,-86.939,-83.939,-79.939,-74.939,-67.939,-63.939003,-59.939003,-55.939003,-48.939003,-39.939003,-35.939003,-34.939003,-31.939003,-26.939003,-20.939003,-18.939003,-17.939003,-14.939003,-16.939003,-19.939003,-17.939003,-20.939003,-28.939003,-30.939003,-31.939003,-35.939003,-39.939003,-42.939003,-43.939003,-42.939003,-39.939003,-38.939003,-42.939003,-53.939003,-43.939003,-23.939003,2.060997,2.060997,-22.939003,-16.939003,-9.939003,-4.939003,-5.939003,-8.939003,-9.939003,-19.939003,-37.939003,-33.939003,-27.939003,-24.939003,-27.939003,-32.939003,-37.939003,-41.939003,-45.939003,-50.939003,-52.939003,-51.939003,-55.939003,-59.939003,-54.939003,-48.939003,-43.939003,-39.939003,-34.939003,-27.939003,-26.939003,-27.939003,-30.939003,-29.939003,-22.939003,-24.939003,-29.939003,-41.939003,-43.939003,-41.939003,-33.939003,-26.939003,-20.939003,-17.939003,-16.939003,-16.939003,-18.939003,-21.939003,-30.939003,-32.939003,-25.939003,-28.939003,-33.939003,-42.939003,-47.939003,-50.939003,-51.939003,-51.939003,-50.939003,-53.939003,-55.939003,-52.939003,-52.939003,-53.939003,-50.939003,-47.939003,-42.939003,-39.939003,-40.939003,-44.939003,-47.939003,-47.939003,-44.939003,-43.939003,-42.939003,-43.939003,-43.939003,-38.939003,-33.939003,-31.939003,-47.939003,-55.939003,-53.939003,-49.939003,-47.939003,-52.939003,-54.939003,-55.939003,-58.939003,-62.939003,-67.939,-72.939,-76.939,-76.939,-80.939,-83.939,-64.939,-61.939003,-74.939,-81.939,-86.939,-87.939,-87.939,-87.939,-88.939,-89.939,-89.939,-89.939,-89.939,-90.939,-90.939,-91.939,-91.939,-91.939,-90.939,-91.939,-92.939,-94.939,-61.939003,-14.939003,32.060997,45.060997,22.060997,-45.939003,-96.939,-95.939,-96.939,-98.939,-98.939,-98.939,-98.939,-98.939,-98.939,-100.939,-101.939,-100.939,-96.939,-97.939,-101.939,-101.939,-99.939,-90.939,-78.939,-65.939,-72.939,-78.939,-83.939,-94.939,-99.939,-86.
939,-74.939,-64.939,-63.939003,-62.939003,-60.939003,-54.939003,-51.939003,-57.939003,-57.939003,-57.939003,-55.939003,-56.939003,-59.939003,-55.939003,-54.939003,-58.939003,-62.939003,-65.939,-61.939003,-58.939003,-57.939003,-58.939003,-58.939003,-54.939003,-48.939003,-43.939003,-42.939003,-42.939003,-42.939003,-35.939003,-31.939003,-36.939003,-50.939003,-65.939,-65.939,-69.939,-75.939,-73.939,-67.939,-58.939003,-48.939003,-40.939003,-39.939003,-34.939003,-28.939003,-51.939003,-73.939,-79.939,-76.939,-72.939,-85.939,-84.939,-68.939,-49.939003,-42.939003,-64.939,-71.939,-71.939,-58.939003,-54.939003,-57.939003,-54.939003,-52.939003,40.060997,29.060997,19.060997,65.061,85.061,81.061,82.061,81.061,77.061,79.061,83.061,84.061,81.061,73.061,76.061,76.061,72.061,60.060997,47.060997,50.060997,52.060997,55.060997,47.060997,41.060997,39.060997,37.060997,33.060997,30.060997,25.060997,18.060997,6.060997,-4.939003,-9.939003,-14.939003,-18.939003,-24.939003,-27.939003,-28.939003,-30.939003,-30.939003,-28.939003,-26.939003,-23.939003,-19.939003,-16.939003,-12.939003,-8.939003,-6.939003,-8.939003,0.06099701,13.060997,20.060997,26.060997,28.060997,32.060997,35.060997,34.060997,33.060997,33.060997,28.060997,23.060997,19.060997,19.060997,15.060997,5.060997,-4.939003,-16.939003,-27.939003,-34.939003,-40.939003,-50.939003,-61.939003,-69.939,-72.939,-73.939,-76.939,-78.939,-79.939,-81.939,-83.939,-86.939,-89.939,-87.939,-38.939003,11.060997,61.060997,42.060997,24.060997,29.060997,-10.939003,-62.939003,-56.939003,-47.939003,-35.939003,-27.939003,-21.939003,-13.939003,-8.939003,-5.939003,-4.939003,-0.939003,6.060997,10.060997,14.060997,18.060997,18.060997,16.060997,15.060997,13.060997,9.060997,8.060997,5.060997,-4.939003,-9.939003,-13.939003,-18.939003,-25.939003,-32.939003,-41.939003,-49.939003,-54.939003,-62.939003,-71.939,-73.939,-76.939,-80.939,-82.939,-84.939,-86.939,-87.939,-89.939,-91.939,-92.939,-94.939,-96.939,-98.939,-99.939,-100.939,-101.939,-102.939,-98.939,-92.939,-90.939,-8
9.939,-90.939,-87.939,-83.939,-80.939,-56.939003,-10.939003,16.060997,30.060997,17.060997,8.060997,2.060997,-16.939003,-19.939003,-6.939003,-3.939003,-0.939003,3.060997,7.060997,9.060997,12.060997,15.060997,17.060997,17.060997,17.060997,17.060997,12.060997,8.060997,3.060997,-2.939003,-8.939003,-12.939003,-16.939003,-22.939003,-31.939003,-42.939003,-49.939003,-54.939003,-58.939003,-69.939,-77.939,-78.939,-80.939,-82.939,-83.939,-85.939,-86.939,-87.939,-89.939,-92.939,-94.939,-95.939,-96.939,-98.939,-100.939,-100.939,-101.939,-100.939,-97.939,-93.939,-89.939,-84.939,-79.939,-73.939,-67.939,-63.939003,-46.939003,-25.939003,-21.939003,-11.939003,5.060997,0.06099701,-7.939003,-16.939003,-18.939003,-17.939003,-22.939003,-26.939003,-26.939003,-23.939003,-22.939003,-22.939003,-23.939003,-23.939003,-12.939003,-8.939003,-10.939003,-6.939003,-2.939003,2.060997,5.060997,7.060997,7.060997,10.060997,13.060997,21.060997,28.060997,31.060997,39.060997,48.060997,49.060997,51.060997,54.060997,56.060997,58.060997,61.060997,42.060997,20.060997,62.060997,80.061,73.061,82.061,60.060997,-37.939003,-81.939,-97.939,-78.939,-66.939,-61.939003,-69.939,-81.939,-97.939,-101.939,-101.939,-94.939,-88.939,-79.939,-70.939,-61.939003,-57.939003,-50.939003,-43.939003,-34.939003,-27.939003,-23.939003,-17.939003,-12.939003,-10.939003,-5.939003,1.060997,6.060997,11.060997,16.060997,15.060997,14.060997,14.060997,14.060997,14.060997,13.060997,11.060997,9.060997,1.060997,-6.939003,-16.939003,-13.939003,-10.939003,-28.939003,-52.939003,-81.939,-91.939,-95.939,-88.939,-79.939,-71.939,-81.939,-89.939,-95.939,-96.939,-96.939,-95.939,-90.939,-84.939,-73.939,-17.939003,82.061,75.061,60.060997,58.060997,-12.939003,-101.939,-96.939,-92.939,-90.939,-85.939,-79.939,-75.939,-71.939,-66.939,-59.939003,-53.939003,-45.939003,-37.939003,-29.939003,-21.939003,-15.939003,-11.939003,-7.939003,-3.939003,1.060997,4.060997,6.060997,10.060997,12.060997,13.060997,14.060997,16.060997,19.060997,9.060997,2.060997,2.060997,-0.939003,
-4.939003,-10.939003,-15.939003,-18.939003,-23.939003,-30.939003,-38.939003,-44.939003,-49.939003,-59.939003,-67.939,-72.939,-77.939,-81.939,-81.939,-12.939003,76.061,78.061,74.061,65.061,-22.939003,-90.939,-93.939,-95.939,-96.939,-96.939,-97.939,-98.939,-95.939,-91.939,-86.939,-82.939,-78.939,-70.939,-65.939,-62.939003,-57.939003,-49.939003,-36.939003,-30.939003,-27.939003,-26.939003,-22.939003,-16.939003,-11.939003,-7.939003,-2.939003,1.060997,5.060997,9.060997,12.060997,13.060997,12.060997,12.060997,12.060997,9.060997,4.060997,4.060997,-0.939003,-10.939003,-15.939003,-17.939003,-12.939003,-20.939003,-30.939003,-15.939003,1.060997,20.060997,24.060997,28.060997,33.060997,34.060997,33.060997,39.060997,43.060997,48.060997,53.060997,47.060997,11.060997,4.060997,12.060997,30.060997,40.060997,42.060997,30.060997,19.060997,14.060997,8.060997,1.060997,-6.939003,-11.939003,-12.939003,-8.939003,-11.939003,-32.939003,-43.939003,-48.939003,-43.939003,-30.939003,-9.939003,-2.939003,0.06099701,-5.939003,3.060997,14.060997,5.060997,-1.939003,-7.939003,0.06099701,4.060997,-4.939003,15.060997,42.060997,34.060997,32.060997,36.060997,26.060997,19.060997,26.060997,26.060997,26.060997,29.060997,31.060997,31.060997,31.060997,30.060997,31.060997,31.060997,32.060997,38.060997,40.060997,38.060997,32.060997,26.060997,22.060997,17.060997,14.060997,21.060997,26.060997,27.060997,-6.939003,-39.939003,-63.939003,-71.939,-67.939,-23.939003,-6.939003,-17.939003,-43.939003,-62.939003,-62.939003,-71.939,-79.939,-38.939003,-16.939003,-13.939003,-8.939003,-5.939003,-8.939003,-10.939003,-12.939003,-18.939003,-26.939003,-34.939003,-32.939003,-33.939003,-40.939003,-45.939003,-51.939003,-60.939003,-66.939,-72.939,-75.939,-78.939,-83.939,-85.939,-85.939,-86.939,-88.939,-89.939,-91.939,-94.939,-94.939,-92.939,-89.939,-88.939,-85.939,-79.939,-74.939,-69.939,-63.939003,-44.939003,-21.939003,-26.939003,-30.939003,-35.939003,-33.939003,-27.939003,-23.939003,-20.939003,-17.939003,-12.939003,-10.939003,-11.93900
3,-10.939003,-11.939003,-13.939003,-17.939003,-23.939003,-28.939003,-31.939003,-31.939003,-39.939003,-46.939003,-52.939003,-57.939003,-62.939003,-68.939,-76.939,-84.939,-86.939,-86.939,-87.939,-73.939,-54.939003,-45.939003,-42.939003,-44.939003,-47.939003,-50.939003,-55.939003,-44.939003,-27.939003,-22.939003,-24.939003,-32.939003,-33.939003,-34.939003,-39.939003,-42.939003,-44.939003,-44.939003,-41.939003,-36.939003,-42.939003,-47.939003,-51.939003,-51.939003,-49.939003,-47.939003,-43.939003,-36.939003,-35.939003,-34.939003,-26.939003,-21.939003,-17.939003,-18.939003,-13.939003,-1.939003,-21.939003,-41.939003,-52.939003,-54.939003,-52.939003,-55.939003,-53.939003,-45.939003,-39.939003,-37.939003,-41.939003,-40.939003,-36.939003,-31.939003,-25.939003,-18.939003,-16.939003,-16.939003,-17.939003,-18.939003,-19.939003,-32.939003,-33.939003,-25.939003,10.060997,30.060997,9.060997,4.060997,7.060997,-2.939003,-9.939003,-11.939003,-8.939003,-7.939003,-10.939003,-11.939003,-11.939003,-18.939003,-20.939003,-17.939003,-24.939003,-30.939003,-26.939003,-26.939003,-29.939003,-32.939003,-35.939003,-38.939003,-36.939003,-35.939003,-37.939003,-36.939003,-33.939003,-33.939003,-31.939003,-27.939003,-24.939003,-22.939003,-24.939003,-22.939003,-20.939003,-18.939003,-20.939003,-27.939003,-27.939003,-28.939003,-29.939003,-28.939003,-25.939003,-36.939003,-39.939003,-34.939003,-31.939003,-31.939003,-35.939003,-36.939003,-34.939003,-38.939003,-41.939003,-44.939003,-44.939003,-44.939003,-46.939003,-48.939003,-50.939003,-50.939003,-49.939003,-49.939003,-52.939003,-57.939003,-62.939003,-53.939003,-36.939003,-18.939003,-14.939003,-23.939003,-52.939003,-72.939,-68.939,-73.939,-81.939,-79.939,-79.939,-80.939,-81.939,-83.939,-90.939,-92.939,-93.939,-95.939,-97.939,-102.939,-100.939,-95.939,-86.939,-79.939,-74.939,-77.939,-77.939,-75.939,-89.939,-100.939,-99.939,-86.939,-70.939,-65.939,-62.939003,-63.939003,-57.939003,-53.939003,-53.939003,-52.939003,-52.939003,-58.939003,-63.939003,-68.939,-55.939
003,-49.939003,-64.939,-74.939,-79.939,-71.939,-65.939,-61.939003,-68.939,-71.939,-64.939,-55.939003,-48.939003,-54.939003,-51.939003,-38.939003,-25.939003,-22.939003,-41.939003,-55.939003,-65.939,-54.939003,-48.939003,-50.939003,-45.939003,-42.939003,-43.939003,-37.939003,-29.939003,-27.939003,-26.939003,-26.939003,-46.939003,-63.939003,-67.939,-54.939003,-37.939003,-54.939003,-69.939,-81.939,-70.939,-61.939003,-62.939003,-60.939003,-57.939003,-53.939003,-49.939003,-46.939003,-50.939003,-53.939003,57.060997,38.060997,18.060997,62.060997,83.061,81.061,76.061,70.061,65.061,62.060997,60.060997,57.060997,50.060997,40.060997,40.060997,38.060997,30.060997,22.060997,15.060997,18.060997,18.060997,14.060997,7.060997,0.06099701,-4.939003,-3.939003,-2.939003,-2.939003,-2.939003,-3.939003,-10.939003,-16.939003,-16.939003,-14.939003,-11.939003,-13.939003,-12.939003,-8.939003,-6.939003,-4.939003,-0.939003,0.06099701,1.060997,4.060997,5.060997,5.060997,4.060997,0.06099701,-8.939003,-7.939003,-1.939003,-0.939003,-0.939003,-1.939003,-2.939003,-4.939003,-11.939003,-15.939003,-16.939003,-22.939003,-26.939003,-30.939003,-31.939003,-34.939003,-43.939003,-52.939003,-61.939003,-67.939,-71.939,-75.939,-78.939,-81.939,-81.939,-79.939,-75.939,-71.939,-68.939,-66.939,-62.939003,-57.939003,-53.939003,-50.939003,-45.939003,-26.939003,-8.939003,6.060997,-1.939003,-9.939003,-9.939003,-17.939003,-24.939003,-14.939003,-8.939003,-3.939003,-3.939003,-2.939003,1.060997,2.060997,1.060997,-3.939003,-6.939003,-6.939003,-7.939003,-10.939003,-12.939003,-17.939003,-22.939003,-28.939003,-32.939003,-35.939003,-38.939003,-41.939003,-50.939003,-54.939003,-57.939003,-61.939003,-66.939,-73.939,-79.939,-84.939,-88.939,-92.939,-97.939,-99.939,-101.939,-103.939,-100.939,-98.939,-96.939,-95.939,-93.939,-89.939,-85.939,-81.939,-76.939,-71.939,-66.939,-64.939,-62.939003,-57.939003,-50.939003,-42.939003,-39.939003,-38.939003,-36.939003,-34.939003,-30.939003,-24.939003,-17.939003,-8.939003,-6.939003,-6.939003,-13.939003
,-13.939003,-11.939003,-5.939003,-2.939003,0.06099701,-0.939003,-2.939003,-3.939003,-8.939003,-13.939003,-14.939003,-17.939003,-20.939003,-21.939003,-23.939003,-29.939003,-34.939003,-39.939003,-43.939003,-47.939003,-53.939003,-56.939003,-59.939003,-64.939,-71.939,-80.939,-84.939,-87.939,-91.939,-97.939,-102.939,-102.939,-99.939,-97.939,-96.939,-94.939,-92.939,-88.939,-83.939,-77.939,-73.939,-70.939,-68.939,-65.939,-60.939003,-56.939003,-54.939003,-52.939003,-46.939003,-38.939003,-33.939003,-29.939003,-27.939003,-21.939003,-15.939003,-11.939003,-8.939003,-8.939003,-10.939003,-7.939003,-1.939003,-8.939003,-14.939003,-15.939003,-14.939003,-12.939003,-12.939003,-12.939003,-10.939003,-3.939003,2.060997,8.060997,6.060997,1.060997,1.060997,1.060997,4.060997,17.060997,31.060997,49.060997,50.060997,44.060997,45.060997,48.060997,54.060997,58.060997,62.060997,63.060997,65.061,68.061,69.061,70.061,71.061,66.061,63.060997,66.061,44.060997,17.060997,50.060997,62.060997,52.060997,59.060997,41.060997,-38.939003,-80.939,-100.939,-81.939,-71.939,-68.939,-75.939,-84.939,-94.939,-99.939,-98.939,-61.939003,-35.939003,-21.939003,-17.939003,-14.939003,-11.939003,-8.939003,-7.939003,-4.939003,-3.939003,-2.939003,-1.939003,-1.939003,-5.939003,-6.939003,-6.939003,-8.939003,-10.939003,-12.939003,-17.939003,-22.939003,-25.939003,-27.939003,-27.939003,-31.939003,-34.939003,-37.939003,-43.939003,-50.939003,-57.939003,-56.939003,-55.939003,-65.939,-76.939,-87.939,-93.939,-97.939,-90.939,-82.939,-74.939,-81.939,-88.939,-97.939,-97.939,-94.939,-91.939,-83.939,-73.939,-61.939003,-22.939003,44.060997,36.060997,23.060997,23.060997,-12.939003,-56.939003,-47.939003,-41.939003,-38.939003,-32.939003,-27.939003,-25.939003,-22.939003,-16.939003,-11.939003,-8.939003,-7.939003,-4.939003,-2.939003,1.060997,2.060997,2.060997,2.060997,0.06099701,-2.939003,-3.939003,-6.939003,-11.939003,-15.939003,-19.939003,-20.939003,-23.939003,-26.939003,-22.939003,-23.939003,-37.939003,-43.939003,-47.939003,-49.939003,-53.939
003,-59.939003,-64.939,-69.939,-75.939,-79.939,-82.939,-89.939,-93.939,-93.939,-94.939,-94.939,-89.939,-22.939003,61.060997,55.060997,47.060997,36.060997,-25.939003,-70.939,-64.939,-61.939003,-58.939003,-55.939003,-52.939003,-50.939003,-45.939003,-40.939003,-35.939003,-33.939003,-31.939003,-25.939003,-21.939003,-20.939003,-17.939003,-12.939003,-9.939003,-6.939003,-5.939003,-6.939003,-6.939003,-7.939003,-8.939003,-10.939003,-8.939003,-11.939003,-14.939003,-17.939003,-20.939003,-22.939003,-26.939003,-29.939003,-30.939003,-35.939003,-40.939003,-42.939003,-46.939003,-52.939003,-41.939003,-28.939003,-18.939003,-31.939003,-48.939003,-24.939003,4.060997,34.060997,32.060997,29.060997,33.060997,28.060997,21.060997,22.060997,22.060997,23.060997,25.060997,20.060997,-4.939003,-9.939003,-5.939003,-0.939003,0.06099701,-2.939003,-7.939003,-13.939003,-20.939003,-24.939003,-27.939003,-30.939003,-32.939003,-32.939003,-25.939003,-22.939003,-32.939003,-40.939003,-47.939003,-44.939003,-31.939003,-10.939003,-5.939003,-3.939003,-6.939003,1.060997,13.060997,4.060997,-1.939003,-5.939003,-1.939003,3.060997,5.060997,18.060997,34.060997,37.060997,30.060997,14.060997,8.060997,9.060997,28.060997,31.060997,28.060997,27.060997,28.060997,28.060997,28.060997,27.060997,29.060997,32.060997,35.060997,27.060997,16.060997,3.060997,15.060997,26.060997,29.060997,24.060997,18.060997,14.060997,14.060997,15.060997,-5.939003,-32.939003,-65.939,-73.939,-67.939,-23.939003,-11.939003,-31.939003,-57.939003,-75.939,-68.939,-70.939,-71.939,-35.939003,-25.939003,-42.939003,-42.939003,-41.939003,-44.939003,-46.939003,-48.939003,-54.939003,-60.939003,-67.939,-66.939,-67.939,-73.939,-75.939,-77.939,-81.939,-83.939,-84.939,-82.939,-79.939,-81.939,-79.939,-76.939,-72.939,-69.939,-67.939,-63.939003,-60.939003,-58.939003,-53.939003,-48.939003,-46.939003,-43.939003,-37.939003,-34.939003,-31.939003,-31.939003,-30.939003,-30.939003,-34.939003,-35.939003,-34.939003,-29.939003,-25.939003,-25.939003,-28.939003,-32.939003,-34.9390
03,-36.939003,-39.939003,-41.939003,-43.939003,-47.939003,-52.939003,-56.939003,-60.939003,-63.939003,-65.939,-70.939,-75.939,-77.939,-77.939,-78.939,-80.939,-80.939,-78.939,-78.939,-77.939,-71.939,-58.939003,-44.939003,-39.939003,-37.939003,-37.939003,-37.939003,-37.939003,-36.939003,-34.939003,-31.939003,-33.939003,-33.939003,-30.939003,-34.939003,-37.939003,-38.939003,-39.939003,-38.939003,-36.939003,-34.939003,-33.939003,-28.939003,-24.939003,-22.939003,-22.939003,-21.939003,-21.939003,-20.939003,-17.939003,-16.939003,-12.939003,-5.939003,-1.939003,-0.939003,-2.939003,-1.939003,4.060997,-19.939003,-40.939003,-49.939003,-48.939003,-45.939003,-45.939003,-42.939003,-38.939003,-37.939003,-39.939003,-44.939003,-42.939003,-36.939003,-29.939003,-22.939003,-17.939003,-16.939003,-18.939003,-20.939003,-20.939003,-20.939003,-31.939003,-32.939003,-23.939003,16.060997,41.060997,21.060997,18.060997,22.060997,12.060997,6.060997,4.060997,11.060997,18.060997,16.060997,17.060997,19.060997,16.060997,11.060997,4.060997,-11.939003,-16.939003,7.060997,12.060997,8.060997,6.060997,4.060997,2.060997,1.060997,1.060997,5.060997,-4.939003,-21.939003,-47.939003,-59.939003,-56.939003,-55.939003,-53.939003,-54.939003,-52.939003,-51.939003,-48.939003,-50.939003,-53.939003,-53.939003,-52.939003,-52.939003,-50.939003,-48.939003,-48.939003,-47.939003,-48.939003,-50.939003,-52.939003,-50.939003,-49.939003,-47.939003,-51.939003,-51.939003,-49.939003,-52.939003,-53.939003,-52.939003,-50.939003,-50.939003,-48.939003,-46.939003,-43.939003,-45.939003,-49.939003,-53.939003,-47.939003,-37.939003,-32.939003,-31.939003,-33.939003,-43.939003,-50.939003,-47.939003,-50.939003,-55.939003,-55.939003,-53.939003,-50.939003,-52.939003,-55.939003,-59.939003,-61.939003,-61.939003,-59.939003,-61.939003,-66.939,-80.939,-91.939,-88.939,-83.939,-76.939,-80.939,-79.939,-75.939,-88.939,-96.939,-85.939,-74.939,-64.939,-60.939003,-59.939003,-60.939003,-58.939003,-56.939003,-53.939003,-51.939003,-50.939003,-52.939003,-56.939
003,-60.939003,-53.939003,-51.939003,-62.939003,-67.939,-67.939,-60.939003,-54.939003,-50.939003,-55.939003,-57.939003,-55.939003,-49.939003,-43.939003,-45.939003,-44.939003,-42.939003,-25.939003,-20.939003,-47.939003,-57.939003,-59.939003,-40.939003,-33.939003,-40.939003,-37.939003,-33.939003,-31.939003,-27.939003,-24.939003,-20.939003,-22.939003,-28.939003,-46.939003,-55.939003,-39.939003,-32.939003,-31.939003,-49.939003,-65.939,-78.939,-73.939,-67.939,-60.939003,-52.939003,-43.939003,-44.939003,-46.939003,-52.939003,-59.939003,-65.939,61.060997,36.060997,9.060997,41.060997,60.060997,63.060997,53.060997,43.060997,39.060997,29.060997,19.060997,14.060997,6.060997,-3.939003,-7.939003,-11.939003,-15.939003,-27.939003,-40.939003,-41.939003,-44.939003,-48.939003,-47.939003,-46.939003,-50.939003,-44.939003,-35.939003,-27.939003,-18.939003,-9.939003,-1.939003,4.060997,8.060997,16.060997,28.060997,32.060997,36.060997,42.060997,45.060997,47.060997,48.060997,43.060997,36.060997,34.060997,28.060997,20.060997,12.060997,2.060997,-14.939003,-30.939003,-44.939003,-54.939003,-62.939003,-69.939,-76.939,-83.939,-95.939,-99.939,-99.939,-99.939,-98.939,-98.939,-98.939,-97.939,-96.939,-96.939,-95.939,-91.939,-87.939,-85.939,-74.939,-63.939003,-54.939003,-44.939003,-34.939003,-27.939003,-21.939003,-16.939003,-8.939003,0.06099701,10.060997,16.060997,19.060997,5.060997,-16.939003,-45.939003,-38.939003,-32.939003,-36.939003,-18.939003,7.060997,17.060997,17.060997,9.060997,-0.939003,-7.939003,-12.939003,-17.939003,-22.939003,-35.939003,-46.939003,-54.939003,-61.939003,-68.939,-77.939,-84.939,-92.939,-99.939,-103.939,-103.939,-103.939,-103.939,-102.939,-102.939,-101.939,-101.939,-101.939,-100.939,-100.939,-99.939,-99.939,-98.939,-96.939,-97.939,-97.939,-96.939,-90.939,-85.939,-81.939,-77.939,-73.939,-65.939,-56.939003,-46.939003,-36.939003,-27.939003,-16.939003,-11.939003,-8.939003,1.060997,10.060997,17.060997,19.060997,20.060997,24.060997,25.060997,26.060997,33.060997,26.060997,5.060997,-11
.939003,-21.939003,-17.939003,-8.939003,-1.939003,-2.939003,-10.939003,-23.939003,-29.939003,-35.939003,-45.939003,-57.939003,-69.939,-76.939,-83.939,-91.939,-93.939,-95.939,-101.939,-103.939,-103.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-95.939,-93.939,-88.939,-84.939,-79.939,-75.939,-71.939,-62.939003,-52.939003,-39.939003,-31.939003,-24.939003,-20.939003,-14.939003,-4.939003,2.060997,6.060997,7.060997,15.060997,25.060997,31.060997,31.060997,25.060997,29.060997,33.060997,35.060997,24.060997,8.060997,0.06099701,-4.939003,-7.939003,-0.939003,7.060997,15.060997,18.060997,20.060997,26.060997,31.060997,36.060997,42.060997,50.060997,63.060997,56.060997,40.060997,18.060997,13.060997,25.060997,45.060997,68.061,98.061,94.061,78.061,76.061,79.061,85.061,81.061,77.061,78.061,73.061,66.061,66.061,66.061,64.061,51.060997,43.060997,47.060997,29.060997,6.060997,19.060997,21.060997,10.060997,11.060997,-0.939003,-44.939003,-77.939,-102.939,-89.939,-82.939,-80.939,-85.939,-90.939,-92.939,-96.939,-93.939,-22.939003,21.060997,38.060997,32.060997,25.060997,26.060997,21.060997,13.060997,5.060997,-1.939003,-5.939003,-11.939003,-19.939003,-30.939003,-39.939003,-48.939003,-57.939003,-66.939,-75.939,-84.939,-92.939,-96.939,-98.939,-100.939,-101.939,-102.939,-102.939,-100.939,-100.939,-100.939,-100.939,-99.939,-99.939,-95.939,-90.939,-95.939,-99.939,-95.939,-88.939,-79.939,-82.939,-89.939,-102.939,-98.939,-88.939,-70.939,-51.939003,-33.939003,-29.939003,-25.939003,-19.939003,-23.939003,-28.939003,-26.939003,-14.939003,1.060997,11.060997,17.060997,21.060997,25.060997,28.060997,25.060997,25.060997,28.060997,29.060997,26.060997,17.060997,10.060997,3.060997,-1.939003,-6.939003,-13.939003,-19.939003,-27.939003,-38.939003,-44.939003,-51.939003,-65.939,-76.939,-82.939,-87.939,-93.939,-102.939,-77.939,-65.939,-92.939,-100.939,-99.939,-94.939,-93.939,-98.939,-98.939,-98.939,-97.939,-96.939,-95.939,-94.939,-91.939,-84.939,-81.939,-76.9
39,-66.939,-29.939003,12.060997,-1.939003,-10.939003,-17.939003,-24.939003,-25.939003,-12.939003,-5.939003,-1.939003,3.060997,8.060997,12.060997,16.060997,20.060997,20.060997,18.060997,14.060997,17.060997,16.060997,14.060997,13.060997,10.060997,-0.939003,-5.939003,-7.939003,-10.939003,-16.939003,-26.939003,-36.939003,-45.939003,-47.939003,-56.939003,-67.939,-77.939,-83.939,-87.939,-91.939,-96.939,-96.939,-99.939,-100.939,-99.939,-99.939,-98.939,-68.939,-40.939003,-24.939003,-38.939003,-59.939003,-31.939003,-0.939003,31.060997,19.060997,7.060997,9.060997,0.06099701,-11.939003,-16.939003,-22.939003,-26.939003,-28.939003,-28.939003,-28.939003,-29.939003,-31.939003,-41.939003,-48.939003,-53.939003,-49.939003,-46.939003,-50.939003,-48.939003,-44.939003,-38.939003,-35.939003,-34.939003,-18.939003,-9.939003,-18.939003,-33.939003,-49.939003,-43.939003,-31.939003,-13.939003,-10.939003,-11.939003,-13.939003,-1.939003,14.060997,1.060997,-4.939003,-5.939003,-6.939003,-1.939003,15.060997,22.060997,27.060997,42.060997,28.060997,-12.939003,-8.939003,5.060997,36.060997,40.060997,32.060997,26.060997,24.060997,26.060997,27.060997,27.060997,29.060997,35.060997,39.060997,7.060997,-19.939003,-40.939003,-5.939003,26.060997,37.060997,37.060997,30.060997,2.060997,-13.939003,-16.939003,-12.939003,-20.939003,-56.939003,-64.939,-59.939003,-26.939003,-22.939003,-44.939003,-67.939,-82.939,-74.939,-67.939,-60.939003,-40.939003,-52.939003,-97.939,-102.939,-101.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-98.939,-97.939,-96.939,-94.939,-90.939,-85.939,-79.939,-73.939,-63.939003,-54.939003,-51.939003,-46.939003,-39.939003,-31.939003,-25.939003,-21.939003,-13.939003,-7.939003,-4.939003,-0.939003,4.060997,5.060997,6.060997,7.060997,4.060997,0.06099701,-8.939003,-14.939003,-20.939003,-17.939003,-17.939003,-23.939003,-36.939003,-49.939003,-53.939003,-63.939003,-75.939,-84.939,-90.939,-93.939,-95.939,-96.939,-99.939,-99.939,-97.939,-98.939,-97.939,-97.939,-97.939,-95.939,-89.939,-82.939,-76.93
9,-70.939,-60.939003,-46.939003,-44.939003,-41.939003,-29.939003,-26.939003,-28.939003,-31.939003,-31.939003,-27.939003,-23.939003,-18.939003,-14.939003,-22.939003,-34.939003,-36.939003,-32.939003,-24.939003,-28.939003,-29.939003,-21.939003,-19.939003,-16.939003,-12.939003,-17.939003,-31.939003,-10.939003,10.060997,21.060997,21.060997,18.060997,13.060997,7.060997,0.06099701,0.06099701,4.060997,9.060997,7.060997,1.060997,-0.939003,-3.939003,-7.939003,-26.939003,-39.939003,-38.939003,-33.939003,-29.939003,-24.939003,-23.939003,-24.939003,-32.939003,-41.939003,-50.939003,-48.939003,-40.939003,-30.939003,-23.939003,-19.939003,-21.939003,-23.939003,-27.939003,-26.939003,-25.939003,-32.939003,-32.939003,-24.939003,3.060997,22.060997,10.060997,8.060997,10.060997,7.060997,4.060997,3.060997,13.060997,23.060997,26.060997,30.060997,34.060997,40.060997,34.060997,15.060997,-4.939003,-6.939003,36.060997,46.060997,41.060997,42.060997,44.060997,45.060997,40.060997,40.060997,53.060997,31.060997,-8.939003,-71.939,-103.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-101.939,-101.939,-101.939,-101.939,-100.939,-98.939,-99.939,-99.939,-72.939,-66.939,-82.939,-94.939,-100.939,-91.939,-87.939,-85.939,-88.939,-84.939,-76.939,-80.939,-83.939,-78.939,-72.939,-68.939,-65.939,-60.939003,-55.939003,-55.939003,-56.939003,-56.939003,-44.939003,-27.939003,-24.939003,-24.939003,-26.939003,-31.939003,-34.939003,-35.939003,-35.939003,-34.939003,-35.939003,-31.939003,-24.939003,-27.939003,-30.939003,-28.939003,-28.939003,-27.939003,-20.939003,-19.939003,-24.939003,-59.939003,-87.939,-93.939,-86.939,-76.939,-81.939,-82.939,-79.939,-89.939,-90.939,-64.939,-55.939003,-53.939003,-53.939003,-54.939003,-55.939003,-58.939003,-59.939003,-53.939003,-51.939003,-49.939003,-45.939003,-45.939003,-47.939003,-51.939003,-56.939003,-58.939003,-54.939003,-47.939003,-41.939003,-38.939003,-36.939003,-36.939003,-37.939003,-40.939003,-40.939003,-36.939003,-30.939003,-34.939003,-48.939003,-29.939003,-22.939003,-5
0.939003,-55.939003,-51.939003,-28.939003,-24.939003,-39.939003,-38.939003,-33.939003,-20.939003,-19.939003,-23.939003,-18.939003,-20.939003,-31.939003,-46.939003,-47.939003,-9.939003,-15.939003,-40.939003,-56.939003,-65.939,-68.939,-68.939,-65.939,-57.939003,-44.939003,-30.939003,-32.939003,-43.939003,-62.939003,-72.939,-79.939,13.060997,1.060997,-13.939003,-5.939003,-2.939003,-3.939003,-8.939003,-13.939003,-13.939003,-17.939003,-20.939003,-21.939003,-20.939003,-18.939003,-14.939003,-10.939003,-9.939003,-10.939003,-11.939003,-3.939003,2.060997,8.060997,13.060997,18.060997,22.060997,24.060997,23.060997,23.060997,19.060997,13.060997,13.060997,11.060997,4.060997,-0.939003,-4.939003,-11.939003,-18.939003,-25.939003,-32.939003,-38.939003,-40.939003,-47.939003,-56.939003,-57.939003,-59.939003,-62.939003,-46.939003,-31.939003,-24.939003,-49.939003,-83.939,-87.939,-87.939,-83.939,-84.939,-82.939,-74.939,-68.939,-64.939,-56.939003,-50.939003,-47.939003,-43.939003,-38.939003,-29.939003,-23.939003,-18.939003,-14.939003,-12.939003,-12.939003,-7.939003,-1.939003,0.06099701,3.060997,5.060997,2.060997,2.060997,5.060997,4.060997,3.060997,1.060997,-4.939003,-10.939003,-6.939003,4.060997,23.060997,20.060997,17.060997,23.060997,2.060997,-30.939003,-50.939003,-62.939003,-65.939,-68.939,-70.939,-73.939,-74.939,-76.939,-80.939,-84.939,-86.939,-86.939,-88.939,-94.939,-97.939,-99.939,-102.939,-102.939,-100.939,-100.939,-98.939,-97.939,-92.939,-86.939,-82.939,-77.939,-69.939,-66.939,-63.939003,-58.939003,-55.939003,-52.939003,-46.939003,-39.939003,-31.939003,-23.939003,-17.939003,-15.939003,-11.939003,-8.939003,-2.939003,2.060997,8.060997,10.060997,13.060997,17.060997,14.060997,8.060997,7.060997,5.060997,2.060997,-3.939003,-8.939003,-10.939003,-16.939003,-23.939003,-37.939003,-34.939003,-15.939003,23.060997,51.060997,47.060997,56.060997,60.060997,-21.939003,-67.939,-76.939,-78.939,-80.939,-84.939,-88.939,-92.939,-94.939,-96.939,-99.939,-100.939,-100.939,-102.939,-101.939,-99.939,-95.939,-9
2.939,-89.939,-84.939,-79.939,-75.939,-69.939,-60.939003,-56.939003,-51.939003,-44.939003,-38.939003,-30.939003,-23.939003,-20.939003,-19.939003,-10.939003,-4.939003,1.060997,3.060997,4.060997,7.060997,10.060997,14.060997,12.060997,9.060997,6.060997,7.060997,3.060997,-9.939003,-13.939003,-13.939003,-13.939003,-18.939003,-30.939003,-35.939003,-41.939003,-49.939003,-39.939003,-24.939003,-21.939003,-15.939003,-6.939003,32.060997,60.060997,60.060997,59.060997,58.060997,63.060997,67.061,71.061,67.061,67.061,78.061,71.061,53.060997,17.060997,9.060997,28.060997,35.060997,46.060997,69.061,65.061,51.060997,41.060997,35.060997,34.060997,30.060997,26.060997,25.060997,19.060997,11.060997,7.060997,5.060997,3.060997,-3.939003,-10.939003,-12.939003,-11.939003,-8.939003,-7.939003,-9.939003,-16.939003,-13.939003,-14.939003,-27.939003,-62.939003,-102.939,-94.939,-85.939,-73.939,-85.939,-95.939,-95.939,-93.939,-88.939,-51.939003,-33.939003,-35.939003,-42.939003,-49.939003,-52.939003,-57.939003,-64.939,-67.939,-69.939,-70.939,-72.939,-75.939,-79.939,-82.939,-85.939,-88.939,-91.939,-94.939,-97.939,-99.939,-101.939,-101.939,-99.939,-94.939,-90.939,-88.939,-83.939,-79.939,-75.939,-67.939,-58.939003,-54.939003,-49.939003,-41.939003,-66.939,-88.939,-96.939,-83.939,-64.939,-81.939,-94.939,-102.939,-84.939,-55.939003,-11.939003,5.060997,9.060997,6.060997,-2.939003,-15.939003,-16.939003,-15.939003,-7.939003,-7.939003,-9.939003,-13.939003,-17.939003,-20.939003,-21.939003,-25.939003,-36.939003,-43.939003,-46.939003,-47.939003,-51.939003,-58.939003,-63.939003,-67.939,-68.939,-70.939,-73.939,-75.939,-78.939,-81.939,-83.939,-86.939,-90.939,-94.939,-96.939,-97.939,-97.939,-98.939,-72.939,-56.939003,-72.939,-73.939,-68.939,-60.939003,-56.939003,-56.939003,-52.939003,-45.939003,-35.939003,-31.939003,-28.939003,-25.939003,-22.939003,-19.939003,-13.939003,-8.939003,-3.939003,-11.939003,-23.939003,-27.939003,-25.939003,-17.939003,-5.939003,4.060997,5.060997,3.060997,0.06099701,-5.939003,-11.939003,-17.93
9003,-21.939003,-24.939003,-26.939003,-32.939003,-41.939003,-46.939003,-52.939003,-58.939003,-61.939003,-64.939,-69.939,-70.939,-71.939,-72.939,-74.939,-77.939,-81.939,-84.939,-84.939,-87.939,-91.939,-92.939,-92.939,-89.939,-86.939,-83.939,-80.939,-77.939,-73.939,-62.939003,-56.939003,-51.939003,-39.939003,-29.939003,-23.939003,-21.939003,-22.939003,-22.939003,-21.939003,-17.939003,-25.939003,-29.939003,-25.939003,-25.939003,-27.939003,-28.939003,-28.939003,-27.939003,-27.939003,-25.939003,-20.939003,-20.939003,-23.939003,-16.939003,-10.939003,-6.939003,-5.939003,-3.939003,1.060997,6.060997,10.060997,16.060997,20.060997,21.060997,36.060997,41.060997,17.060997,-20.939003,-60.939003,-48.939003,-35.939003,-19.939003,-15.939003,-12.939003,-10.939003,-1.939003,9.060997,2.060997,1.060997,4.060997,-1.939003,-0.939003,16.060997,31.060997,43.060997,39.060997,16.060997,-23.939003,6.060997,34.060997,41.060997,39.060997,34.060997,35.060997,37.060997,40.060997,37.060997,37.060997,45.060997,34.060997,14.060997,-22.939003,-29.939003,-5.939003,21.060997,33.060997,4.060997,14.060997,35.060997,2.060997,-14.939003,-14.939003,2.060997,3.060997,-35.939003,-47.939003,-47.939003,-27.939003,-27.939003,-45.939003,-60.939003,-72.939,-80.939,-73.939,-60.939003,-45.939003,-57.939003,-98.939,-90.939,-79.939,-73.939,-69.939,-64.939,-58.939003,-53.939003,-50.939003,-44.939003,-38.939003,-31.939003,-29.939003,-29.939003,-24.939003,-21.939003,-20.939003,-19.939003,-17.939003,-13.939003,-13.939003,-14.939003,-16.939003,-18.939003,-21.939003,-24.939003,-29.939003,-35.939003,-38.939003,-40.939003,-44.939003,-48.939003,-53.939003,-60.939003,-65.939,-67.939,-27.939003,22.060997,23.060997,7.060997,-26.939003,-60.939003,-85.939,-86.939,-85.939,-84.939,-81.939,-78.939,-76.939,-72.939,-68.939,-65.939,-58.939003,-50.939003,-45.939003,-42.939003,-42.939003,-37.939003,-32.939003,-28.939003,-25.939003,-23.939003,-19.939003,-16.939003,-15.939003,-19.939003,-22.939003,-24.939003,-29.939003,-34.939003,-36.939003,-
36.939003,-34.939003,-32.939003,-33.939003,-42.939003,-35.939003,-19.939003,-2.939003,-4.939003,-25.939003,-17.939003,-8.939003,-5.939003,-4.939003,-3.939003,-8.939003,-20.939003,-40.939003,-27.939003,-14.939003,-10.939003,-12.939003,-16.939003,-14.939003,-17.939003,-24.939003,-26.939003,-28.939003,-32.939003,-35.939003,-37.939003,-39.939003,-43.939003,-51.939003,-54.939003,-55.939003,-55.939003,-54.939003,-55.939003,-56.939003,-57.939003,-57.939003,-62.939003,-68.939,-72.939,-69.939,-62.939003,-58.939003,-56.939003,-54.939003,-54.939003,-54.939003,-53.939003,-51.939003,-50.939003,-55.939003,-60.939003,-64.939,-62.939003,-61.939003,-64.939,-64.939,-62.939003,-65.939,-66.939,-64.939,-61.939003,-57.939003,-53.939003,-52.939003,-52.939003,-49.939003,-48.939003,-49.939003,-47.939003,-41.939003,-30.939003,-31.939003,-37.939003,-36.939003,-34.939003,-31.939003,-34.939003,-34.939003,-21.939003,-31.939003,-52.939003,-84.939,-99.939,-96.939,-95.939,-94.939,-90.939,-85.939,-80.939,-79.939,-78.939,-77.939,-79.939,-82.939,-85.939,-88.939,-88.939,-67.939,-67.939,-90.939,-98.939,-101.939,-99.939,-98.939,-97.939,-98.939,-97.939,-94.939,-95.939,-96.939,-95.939,-93.939,-91.939,-90.939,-89.939,-87.939,-87.939,-87.939,-87.939,-50.939003,0.06099701,30.060997,25.060997,-14.939003,-52.939003,-79.939,-80.939,-80.939,-80.939,-80.939,-79.939,-76.939,-76.939,-77.939,-76.939,-77.939,-77.939,-74.939,-74.939,-77.939,-86.939,-92.939,-88.939,-84.939,-81.939,-83.939,-83.939,-82.939,-93.939,-99.939,-89.939,-72.939,-52.939003,-44.939003,-41.939003,-43.939003,-46.939003,-48.939003,-48.939003,-51.939003,-54.939003,-52.939003,-51.939003,-51.939003,-54.939003,-57.939003,-58.939003,-54.939003,-48.939003,-46.939003,-47.939003,-49.939003,-52.939003,-54.939003,-55.939003,-48.939003,-37.939003,-37.939003,-43.939003,-53.939003,-34.939003,-19.939003,-17.939003,-28.939003,-44.939003,-41.939003,-45.939003,-56.939003,-63.939003,-58.939003,-25.939003,-21.939003,-28.939003,-25.939003,-25.939003,-31.939003,-29.93900
3,-27.939003,-24.939003,-38.939003,-57.939003,-62.939003,-67.939,-73.939,-67.939,-58.939003,-48.939003,-39.939003,-32.939003,-27.939003,-35.939003,-56.939003,-67.939,-74.939,-10.939003,-14.939003,-19.939003,-17.939003,-17.939003,-18.939003,-17.939003,-16.939003,-11.939003,-9.939003,-8.939003,-6.939003,-0.939003,6.060997,11.060997,16.060997,21.060997,23.060997,22.060997,8.060997,3.060997,8.060997,10.060997,11.060997,13.060997,11.060997,6.060997,1.060997,-5.939003,-14.939003,-18.939003,-22.939003,-30.939003,-38.939003,-46.939003,-55.939003,-62.939003,-68.939,-74.939,-78.939,-79.939,-82.939,-85.939,-82.939,-79.939,-76.939,-56.939003,-36.939003,-23.939003,-39.939003,-64.939,-64.939,-60.939003,-53.939003,-49.939003,-43.939003,-32.939003,-25.939003,-21.939003,-14.939003,-9.939003,-8.939003,-6.939003,-4.939003,1.060997,5.060997,7.060997,8.060997,6.060997,2.060997,1.060997,0.06099701,-3.939003,-6.939003,-10.939003,-17.939003,-20.939003,-20.939003,-24.939003,-29.939003,-34.939003,-42.939003,-49.939003,-36.939003,4.060997,75.061,64.061,55.060997,70.061,25.060997,-41.939003,-82.939,-101.939,-100.939,-99.939,-98.939,-97.939,-96.939,-95.939,-93.939,-91.939,-87.939,-83.939,-80.939,-80.939,-79.939,-77.939,-75.939,-72.939,-68.939,-63.939003,-59.939003,-57.939003,-49.939003,-42.939003,-39.939003,-33.939003,-25.939003,-22.939003,-20.939003,-18.939003,-16.939003,-15.939003,-11.939003,-7.939003,-1.939003,0.06099701,2.060997,1.060997,2.060997,2.060997,1.060997,-0.939003,-2.939003,-3.939003,-5.939003,-6.939003,-12.939003,-20.939003,-25.939003,-30.939003,-36.939003,-42.939003,-47.939003,-51.939003,-58.939003,-65.939,-82.939,-72.939,-37.939003,41.060997,94.061,80.061,87.061,92.061,-23.939003,-86.939,-97.939,-93.939,-88.939,-87.939,-85.939,-82.939,-79.939,-77.939,-75.939,-73.939,-70.939,-69.939,-65.939,-60.939003,-55.939003,-51.939003,-46.939003,-40.939003,-34.939003,-29.939003,-24.939003,-18.939003,-16.939003,-11.939003,-5.939003,-3.939003,-0.939003,4.060997,5.060997,3.060997,6.060997,6.06
0997,7.060997,3.060997,-1.939003,-6.939003,-7.939003,-7.939003,-12.939003,-17.939003,-23.939003,-25.939003,-31.939003,-44.939003,-51.939003,-55.939003,-56.939003,-62.939003,-72.939,-79.939,-86.939,-95.939,-72.939,-39.939003,-30.939003,-19.939003,-7.939003,40.060997,75.061,72.061,68.061,62.060997,63.060997,63.060997,62.060997,55.060997,51.060997,56.060997,48.060997,33.060997,5.060997,-0.939003,16.060997,15.060997,17.060997,25.060997,22.060997,15.060997,6.060997,-0.939003,-2.939003,-4.939003,-6.939003,-6.939003,-7.939003,-9.939003,-13.939003,-15.939003,-14.939003,-16.939003,-17.939003,-16.939003,-13.939003,-8.939003,-3.939003,-0.939003,-0.939003,2.060997,6.060997,17.060997,-14.939003,-62.939003,-85.939,-92.939,-80.939,-90.939,-100.939,-98.939,-95.939,-91.939,-78.939,-74.939,-81.939,-86.939,-91.939,-92.939,-94.939,-97.939,-95.939,-94.939,-93.939,-91.939,-89.939,-87.939,-85.939,-82.939,-79.939,-75.939,-73.939,-72.939,-70.939,-68.939,-65.939,-61.939003,-54.939003,-48.939003,-42.939003,-39.939003,-35.939003,-31.939003,-23.939003,-15.939003,-14.939003,-10.939003,-4.939003,-37.939003,-69.939,-94.939,-88.939,-71.939,-88.939,-98.939,-102.939,-80.939,-49.939003,-6.939003,-0.939003,-10.939003,-10.939003,-1.939003,17.060997,16.060997,17.060997,27.060997,-1.939003,-42.939003,-50.939003,-56.939003,-60.939003,-63.939003,-68.939,-78.939,-84.939,-88.939,-88.939,-89.939,-91.939,-92.939,-91.939,-88.939,-85.939,-84.939,-81.939,-79.939,-77.939,-76.939,-74.939,-70.939,-68.939,-67.939,-65.939,-62.939003,-57.939003,-44.939003,-34.939003,-36.939003,-33.939003,-29.939003,-26.939003,-23.939003,-20.939003,-16.939003,-11.939003,-4.939003,-2.939003,-1.939003,0.06099701,0.06099701,-0.939003,0.06099701,-0.939003,-1.939003,-6.939003,-12.939003,-8.939003,-1.939003,8.060997,-6.939003,-18.939003,-22.939003,-26.939003,-31.939003,-39.939003,-47.939003,-54.939003,-60.939003,-64.939,-66.939,-72.939,-79.939,-83.939,-87.939,-91.939,-92.939,-91.939,-89.939,-87.939,-86.939,-82.939,-79.939,-78.939,-76.939,-75.9
39,-75.939,-72.939,-68.939,-65.939,-60.939003,-55.939003,-49.939003,-43.939003,-38.939003,-35.939003,-32.939003,-21.939003,-16.939003,-15.939003,-17.939003,-19.939003,-20.939003,-14.939003,-7.939003,-16.939003,-22.939003,-24.939003,-25.939003,-25.939003,-21.939003,-16.939003,-13.939003,-9.939003,-5.939003,-3.939003,1.060997,2.060997,-6.939003,-9.939003,-8.939003,6.060997,21.060997,33.060997,26.060997,21.060997,25.060997,28.060997,31.060997,33.060997,33.060997,32.060997,41.060997,40.060997,15.060997,-23.939003,-63.939003,-51.939003,-38.939003,-22.939003,-14.939003,-9.939003,-8.939003,-3.939003,5.060997,6.060997,6.060997,5.060997,-5.939003,-10.939003,-1.939003,23.060997,51.060997,41.060997,15.060997,-25.939003,10.060997,42.060997,43.060997,42.060997,38.060997,33.060997,31.060997,30.060997,30.060997,35.060997,49.060997,32.060997,2.060997,-29.939003,-21.939003,26.060997,34.060997,25.060997,-13.939003,5.060997,41.060997,7.060997,-14.939003,-25.939003,-6.939003,0.06099701,-31.939003,-40.939003,-40.939003,-25.939003,-27.939003,-43.939003,-53.939003,-65.939,-84.939,-77.939,-61.939003,-42.939003,-43.939003,-64.939,-55.939003,-44.939003,-37.939003,-34.939003,-31.939003,-28.939003,-25.939003,-23.939003,-19.939003,-16.939003,-11.939003,-13.939003,-17.939003,-15.939003,-18.939003,-22.939003,-28.939003,-31.939003,-28.939003,-29.939003,-33.939003,-38.939003,-42.939003,-47.939003,-53.939003,-59.939003,-67.939,-69.939,-72.939,-74.939,-77.939,-80.939,-84.939,-85.939,-81.939,-36.939003,18.060997,15.060997,-0.939003,-29.939003,-54.939003,-70.939,-69.939,-64.939,-58.939003,-53.939003,-48.939003,-44.939003,-41.939003,-37.939003,-34.939003,-30.939003,-26.939003,-22.939003,-22.939003,-24.939003,-22.939003,-21.939003,-22.939003,-22.939003,-23.939003,-25.939003,-27.939003,-29.939003,-35.939003,-40.939003,-47.939003,-46.939003,-42.939003,-41.939003,-40.939003,-41.939003,-42.939003,-45.939003,-56.939003,-42.939003,-17.939003,-0.939003,-4.939003,-29.939003,-21.939003,-14.939003,-15.939003,-15.9
39003,-14.939003,-19.939003,-26.939003,-37.939003,-31.939003,-27.939003,-26.939003,-30.939003,-36.939003,-38.939003,-42.939003,-49.939003,-54.939003,-59.939003,-66.939,-68.939,-69.939,-71.939,-75.939,-81.939,-76.939,-71.939,-72.939,-73.939,-75.939,-77.939,-79.939,-80.939,-80.939,-82.939,-84.939,-81.939,-77.939,-78.939,-78.939,-76.939,-75.939,-73.939,-71.939,-71.939,-71.939,-74.939,-78.939,-85.939,-92.939,-98.939,-96.939,-96.939,-96.939,-97.939,-96.939,-94.939,-96.939,-95.939,-93.939,-94.939,-94.939,-92.939,-89.939,-85.939,-76.939,-70.939,-75.939,-79.939,-83.939,-83.939,-82.939,-79.939,-81.939,-81.939,-73.939,-75.939,-83.939,-91.939,-94.939,-91.939,-92.939,-91.939,-85.939,-80.939,-75.939,-74.939,-70.939,-66.939,-65.939,-65.939,-68.939,-69.939,-66.939,-52.939003,-53.939003,-69.939,-74.939,-76.939,-78.939,-77.939,-75.939,-76.939,-77.939,-79.939,-80.939,-81.939,-78.939,-80.939,-84.939,-81.939,-80.939,-81.939,-81.939,-81.939,-81.939,-49.939003,-5.939003,24.060997,19.060997,-19.939003,-58.939003,-86.939,-88.939,-88.939,-88.939,-88.939,-90.939,-91.939,-89.939,-89.939,-91.939,-93.939,-95.939,-95.939,-98.939,-102.939,-99.939,-95.939,-88.939,-85.939,-85.939,-84.939,-84.939,-83.939,-94.939,-103.939,-102.939,-81.939,-52.939003,-42.939003,-39.939003,-43.939003,-46.939003,-50.939003,-54.939003,-59.939003,-64.939,-62.939003,-60.939003,-58.939003,-58.939003,-59.939003,-60.939003,-56.939003,-51.939003,-52.939003,-55.939003,-58.939003,-62.939003,-64.939,-64.939,-56.939003,-46.939003,-42.939003,-46.939003,-59.939003,-45.939003,-26.939003,-4.939003,-13.939003,-34.939003,-31.939003,-39.939003,-57.939003,-68.939,-65.939,-28.939003,-24.939003,-32.939003,-30.939003,-29.939003,-28.939003,-20.939003,-17.939003,-31.939003,-45.939003,-58.939003,-65.939,-71.939,-78.939,-64.939,-50.939003,-42.939003,-34.939003,-28.939003,-24.939003,-30.939003,-45.939003,-57.939003,-65.939,-10.939003,-10.939003,-9.939003,5.060997,15.060997,20.060997,27.060997,34.060997,45.060997,52.060997,56.060997,60.060997,66.0
61,73.061,69.061,69.061,78.061,76.061,63.060997,-6.939003,-42.939003,-46.939003,-56.939003,-65.939,-75.939,-83.939,-88.939,-92.939,-94.939,-95.939,-95.939,-96.939,-96.939,-97.939,-98.939,-99.939,-95.939,-87.939,-80.939,-73.939,-67.939,-59.939003,-50.939003,-39.939003,-30.939003,-21.939003,-17.939003,-13.939003,-11.939003,-1.939003,11.060997,14.060997,17.060997,21.060997,28.060997,33.060997,30.060997,29.060997,30.060997,26.060997,23.060997,18.060997,12.060997,5.060997,-2.939003,-10.939003,-17.939003,-22.939003,-28.939003,-38.939003,-48.939003,-57.939003,-66.939,-75.939,-83.939,-88.939,-91.939,-94.939,-95.939,-96.939,-97.939,-97.939,-97.939,-84.939,-16.939003,108.061,95.061,81.061,103.061,52.060997,-26.939003,-77.939,-100.939,-95.939,-92.939,-89.939,-85.939,-82.939,-79.939,-74.939,-67.939,-57.939003,-50.939003,-44.939003,-37.939003,-31.939003,-24.939003,-18.939003,-11.939003,-5.939003,5.060997,15.060997,19.060997,25.060997,29.060997,28.060997,29.060997,30.060997,32.060997,29.060997,22.060997,18.060997,14.060997,7.060997,-0.939003,-7.939003,-16.939003,-24.939003,-31.939003,-36.939003,-42.939003,-53.939003,-65.939,-77.939,-80.939,-84.939,-89.939,-93.939,-95.939,-96.939,-96.939,-97.939,-97.939,-98.939,-98.939,-99.939,-100.939,-101.939,-87.939,-58.939003,42.060997,109.061,81.061,86.061,93.061,-7.939003,-67.939,-87.939,-73.939,-59.939003,-54.939003,-48.939003,-40.939003,-32.939003,-25.939003,-20.939003,-13.939003,-5.939003,-1.939003,5.060997,12.060997,17.060997,21.060997,26.060997,30.060997,34.060997,37.060997,33.060997,26.060997,22.060997,19.060997,18.060997,7.060997,-3.939003,-10.939003,-11.939003,-13.939003,-28.939003,-42.939003,-53.939003,-62.939003,-72.939,-81.939,-85.939,-88.939,-93.939,-95.939,-96.939,-96.939,-96.939,-98.939,-98.939,-99.939,-99.939,-99.939,-100.939,-101.939,-101.939,-102.939,-75.939,-37.939003,-27.939003,-18.939003,-10.939003,25.060997,50.060997,50.060997,43.060997,32.060997,26.060997,18.060997,10.060997,7.060997,3.060997,-4.939003,-11.939003,-18.93
9003,-19.939003,-15.939003,-9.939003,-13.939003,-19.939003,-32.939003,-32.939003,-27.939003,-28.939003,-27.939003,-23.939003,-22.939003,-21.939003,-15.939003,-6.939003,3.060997,2.060997,5.060997,10.060997,14.060997,21.060997,36.060997,25.060997,8.060997,31.060997,48.060997,60.060997,57.060997,63.060997,92.061,67.061,16.060997,-62.939003,-101.939,-99.939,-101.939,-103.939,-103.939,-102.939,-102.939,-101.939,-100.939,-100.939,-101.939,-100.939,-95.939,-89.939,-85.939,-80.939,-75.939,-72.939,-66.939,-60.939003,-55.939003,-48.939003,-39.939003,-30.939003,-20.939003,-12.939003,-9.939003,-5.939003,1.060997,8.060997,13.060997,17.060997,24.060997,34.060997,33.060997,30.060997,32.060997,31.060997,28.060997,22.060997,20.060997,20.060997,-8.939003,-44.939003,-89.939,-102.939,-100.939,-102.939,-103.939,-103.939,-86.939,-70.939,-57.939003,-71.939,-94.939,-80.939,-23.939003,78.061,77.061,70.061,78.061,2.060997,-97.939,-98.939,-99.939,-99.939,-99.939,-100.939,-99.939,-99.939,-97.939,-93.939,-88.939,-82.939,-75.939,-67.939,-59.939003,-52.939003,-45.939003,-38.939003,-32.939003,-25.939003,-21.939003,-16.939003,-4.939003,1.060997,4.060997,7.060997,12.060997,19.060997,8.060997,2.060997,16.060997,19.060997,17.060997,6.060997,3.060997,10.060997,9.060997,5.060997,-3.939003,-9.939003,-13.939003,-17.939003,-23.939003,-29.939003,-40.939003,-51.939003,-60.939003,-15.939003,45.060997,55.060997,59.060997,59.060997,-27.939003,-95.939,-96.939,-96.939,-96.939,-97.939,-98.939,-99.939,-99.939,-99.939,-100.939,-100.939,-99.939,-93.939,-88.939,-84.939,-79.939,-72.939,-62.939003,-56.939003,-53.939003,-41.939003,-32.939003,-27.939003,-21.939003,-18.939003,-20.939003,-11.939003,0.06099701,5.060997,10.060997,14.060997,18.060997,23.060997,28.060997,26.060997,22.060997,23.060997,19.060997,10.060997,-1.939003,-11.939003,-15.939003,-15.939003,-13.939003,-11.939003,-4.939003,10.060997,17.060997,20.060997,23.060997,26.060997,30.060997,41.060997,47.060997,46.060997,60.060997,57.060997,12.060997,5.060997,14.0609
97,29.060997,46.060997,65.061,46.060997,28.060997,21.060997,19.060997,16.060997,10.060997,4.060997,-2.939003,-5.939003,-12.939003,-24.939003,-40.939003,-56.939003,-50.939003,-39.939003,-22.939003,-8.939003,-1.939003,-9.939003,-6.939003,2.060997,11.060997,9.060997,-1.939003,-16.939003,-30.939003,-39.939003,-1.939003,50.060997,49.060997,26.060997,-18.939003,5.060997,30.060997,44.060997,47.060997,43.060997,20.060997,6.060997,-2.939003,7.060997,22.060997,42.060997,29.060997,4.060997,-14.939003,3.060997,56.060997,33.060997,4.060997,-18.939003,9.060997,50.060997,18.060997,-15.939003,-50.939003,-37.939003,-28.939003,-42.939003,-43.939003,-37.939003,-20.939003,-21.939003,-38.939003,-48.939003,-62.939003,-85.939,-80.939,-62.939003,-33.939003,-10.939003,4.060997,3.060997,2.060997,7.060997,4.060997,-0.939003,-8.939003,-14.939003,-18.939003,-24.939003,-31.939003,-36.939003,-45.939003,-56.939003,-60.939003,-69.939,-79.939,-90.939,-96.939,-96.939,-96.939,-97.939,-97.939,-97.939,-98.939,-98.939,-98.939,-98.939,-94.939,-89.939,-85.939,-79.939,-72.939,-66.939,-60.939003,-49.939003,-40.939003,-33.939003,-41.939003,-41.939003,-31.939003,-17.939003,-5.939003,-4.939003,-1.939003,0.06099701,-0.939003,-0.939003,1.060997,-0.939003,-3.939003,-7.939003,-15.939003,-24.939003,-28.939003,-35.939003,-42.939003,-53.939003,-64.939,-71.939,-74.939,-74.939,-87.939,-92.939,-90.939,-92.939,-95.939,-97.939,-79.939,-52.939003,-46.939003,-45.939003,-50.939003,-51.939003,-53.939003,-56.939003,-45.939003,-28.939003,-29.939003,-31.939003,-36.939003,-42.939003,-48.939003,-50.939003,-50.939003,-49.939003,-44.939003,-35.939003,-22.939003,-24.939003,-26.939003,-27.939003,-33.939003,-42.939003,-57.939003,-68.939,-76.939,-82.939,-87.939,-91.939,-93.939,-93.939,-96.939,-98.939,-96.939,-91.939,-88.939,-89.939,-89.939,-90.939,-87.939,-88.939,-91.939,-87.939,-85.939,-85.939,-85.939,-85.939,-89.939,-89.939,-85.939,-83.939,-81.939,-82.939,-84.939,-86.939,-87.939,-87.939,-88.939,-87.939,-87.939,-85.939,-87.939,-91.939,-
87.939,-86.939,-87.939,-89.939,-90.939,-94.939,-93.939,-90.939,-89.939,-89.939,-91.939,-91.939,-92.939,-97.939,-98.939,-97.939,-99.939,-100.939,-97.939,-99.939,-101.939,-100.939,-100.939,-100.939,-92.939,-88.939,-88.939,-91.939,-93.939,-88.939,-85.939,-84.939,-85.939,-79.939,-69.939,-59.939003,-50.939003,-46.939003,-40.939003,-33.939003,-27.939003,-23.939003,-19.939003,-21.939003,-24.939003,-27.939003,-24.939003,-19.939003,-21.939003,-25.939003,-30.939003,-35.939003,-36.939003,-27.939003,-34.939003,-45.939003,-36.939003,-35.939003,-37.939003,-36.939003,-36.939003,-38.939003,-42.939003,-45.939003,-43.939003,-41.939003,-40.939003,-48.939003,-55.939003,-58.939003,-59.939003,-59.939003,-59.939003,-63.939003,-70.939,-66.939,-65.939,-73.939,-78.939,-82.939,-84.939,-90.939,-99.939,-97.939,-94.939,-91.939,-89.939,-88.939,-85.939,-83.939,-81.939,-93.939,-103.939,-103.939,-81.939,-53.939003,-47.939003,-48.939003,-55.939003,-59.939003,-65.939,-72.939,-75.939,-77.939,-75.939,-72.939,-68.939,-63.939003,-61.939003,-63.939003,-60.939003,-56.939003,-61.939003,-62.939003,-61.939003,-65.939,-67.939,-66.939,-65.939,-63.939003,-43.939003,-43.939003,-65.939,-60.939003,-45.939003,-11.939003,-11.939003,-22.939003,0.06099701,-6.939003,-42.939003,-54.939003,-54.939003,-29.939003,-27.939003,-34.939003,-35.939003,-30.939003,-22.939003,-17.939003,-18.939003,-29.939003,-37.939003,-43.939003,-65.939,-78.939,-83.939,-59.939003,-39.939003,-39.939003,-29.939003,-18.939003,-25.939003,-29.939003,-30.939003,-42.939003,-51.939003,42.060997,26.060997,7.060997,23.060997,43.060997,66.061,68.061,68.061,72.061,74.061,74.061,77.061,79.061,78.061,76.061,75.061,78.061,75.061,64.061,-15.939003,-63.939003,-79.939,-74.939,-70.939,-72.939,-72.939,-71.939,-66.939,-60.939003,-52.939003,-51.939003,-52.939003,-52.939003,-46.939003,-37.939003,-37.939003,-33.939003,-26.939003,-22.939003,-19.939003,-17.939003,-11.939003,-3.939003,-0.939003,1.060997,2.060997,0.06099701,-1.939003,-2.939003,-1.939003,-1.939003,-2.939003,-5.
939003,-8.939003,-10.939003,-14.939003,-23.939003,-28.939003,-32.939003,-36.939003,-40.939003,-45.939003,-50.939003,-55.939003,-59.939003,-63.939003,-67.939,-69.939,-72.939,-76.939,-81.939,-85.939,-89.939,-94.939,-97.939,-98.939,-99.939,-100.939,-98.939,-97.939,-96.939,-93.939,-88.939,-79.939,-28.939003,65.061,59.060997,48.060997,50.060997,20.060997,-23.939003,-45.939003,-52.939003,-41.939003,-38.939003,-34.939003,-28.939003,-24.939003,-20.939003,-16.939003,-10.939003,-1.939003,0.06099701,2.060997,5.060997,5.060997,4.060997,5.060997,5.060997,3.060997,0.06099701,-2.939003,-3.939003,-7.939003,-13.939003,-17.939003,-21.939003,-26.939003,-30.939003,-35.939003,-40.939003,-45.939003,-49.939003,-54.939003,-57.939003,-61.939003,-66.939,-70.939,-73.939,-75.939,-78.939,-84.939,-89.939,-95.939,-96.939,-98.939,-99.939,-99.939,-97.939,-93.939,-89.939,-86.939,-84.939,-83.939,-81.939,-75.939,-69.939,-68.939,-59.939003,-43.939003,9.060997,42.060997,21.060997,19.060997,23.060997,-3.939003,-17.939003,-17.939003,-11.939003,-5.939003,-0.939003,2.060997,4.060997,7.060997,8.060997,7.060997,9.060997,10.060997,3.060997,0.06099701,-1.939003,-4.939003,-7.939003,-9.939003,-17.939003,-22.939003,-22.939003,-29.939003,-37.939003,-43.939003,-46.939003,-46.939003,-51.939003,-56.939003,-61.939003,-63.939003,-64.939,-72.939,-78.939,-84.939,-88.939,-92.939,-96.939,-99.939,-100.939,-101.939,-99.939,-95.939,-93.939,-91.939,-88.939,-82.939,-76.939,-70.939,-68.939,-67.939,-62.939003,-56.939003,-47.939003,-32.939003,-16.939003,-16.939003,-12.939003,-3.939003,-1.939003,-2.939003,-0.939003,-4.939003,-8.939003,-10.939003,-11.939003,-13.939003,-10.939003,-8.939003,-9.939003,-10.939003,-11.939003,-8.939003,-6.939003,-5.939003,-1.939003,2.060997,8.060997,11.060997,13.060997,12.060997,15.060997,22.060997,26.060997,30.060997,38.060997,44.060997,50.060997,50.060997,53.060997,56.060997,54.060997,56.060997,70.061,48.060997,18.060997,44.060997,63.060997,73.061,66.061,63.060997,79.061,66.061,35.060997,-55.939003,-93.9
39,-76.939,-91.939,-101.939,-93.939,-77.939,-59.939003,-59.939003,-55.939003,-47.939003,-46.939003,-42.939003,-37.939003,-30.939003,-24.939003,-21.939003,-17.939003,-14.939003,-11.939003,-7.939003,-7.939003,-4.939003,-0.939003,3.060997,6.060997,8.060997,4.060997,1.060997,-1.939003,-2.939003,-4.939003,-10.939003,-14.939003,-13.939003,-19.939003,-25.939003,-31.939003,-34.939003,-37.939003,-43.939003,-46.939003,-47.939003,-60.939003,-74.939,-90.939,-79.939,-63.939003,-83.939,-95.939,-98.939,-94.939,-89.939,-85.939,-91.939,-100.939,-86.939,-32.939003,62.060997,54.060997,40.060997,36.060997,-12.939003,-71.939,-63.939003,-58.939003,-56.939003,-54.939003,-51.939003,-42.939003,-36.939003,-30.939003,-28.939003,-24.939003,-18.939003,-13.939003,-8.939003,-5.939003,-2.939003,0.06099701,0.06099701,0.06099701,-0.939003,2.060997,3.060997,4.060997,2.060997,-1.939003,-6.939003,-9.939003,-11.939003,-13.939003,-17.939003,-26.939003,-33.939003,-38.939003,-38.939003,-42.939003,-51.939003,-52.939003,-54.939003,-58.939003,-62.939003,-64.939,-66.939,-68.939,-72.939,-76.939,-81.939,-86.939,-26.939003,54.060997,59.060997,55.060997,42.060997,-28.939003,-81.939,-77.939,-74.939,-71.939,-66.939,-62.939003,-59.939003,-53.939003,-48.939003,-46.939003,-43.939003,-40.939003,-34.939003,-29.939003,-26.939003,-21.939003,-15.939003,-12.939003,-10.939003,-8.939003,-4.939003,-3.939003,-6.939003,-1.939003,0.06099701,-4.939003,-6.939003,-5.939003,-6.939003,-8.939003,-11.939003,-18.939003,-24.939003,-25.939003,-31.939003,-37.939003,-42.939003,-47.939003,-52.939003,-37.939003,-25.939003,-22.939003,-32.939003,-44.939003,-19.939003,7.060997,39.060997,36.060997,32.060997,33.060997,30.060997,26.060997,27.060997,27.060997,22.060997,26.060997,20.060997,-5.939003,-8.939003,-1.939003,0.06099701,1.060997,2.060997,-2.939003,-8.939003,-15.939003,-17.939003,-18.939003,-20.939003,-22.939003,-23.939003,-25.939003,-28.939003,-30.939003,-41.939003,-53.939003,-51.939003,-45.939003,-38.939003,-21.939003,-10.939003,-12.939003,-
7.939003,1.060997,-3.939003,-5.939003,-4.939003,-6.939003,-9.939003,-10.939003,0.06099701,16.060997,37.060997,36.060997,13.060997,19.060997,27.060997,36.060997,26.060997,11.060997,4.060997,6.060997,16.060997,15.060997,15.060997,21.060997,17.060997,7.060997,-6.939003,10.060997,58.060997,40.060997,12.060997,-22.939003,-1.939003,37.060997,20.060997,0.06099701,-20.939003,-9.939003,-5.939003,-23.939003,-26.939003,-21.939003,-9.939003,-16.939003,-42.939003,-55.939003,-69.939,-89.939,-68.939,-31.939003,-26.939003,-30.939003,-43.939003,-45.939003,-47.939003,-49.939003,-52.939003,-56.939003,-61.939003,-64.939,-65.939,-69.939,-72.939,-75.939,-79.939,-84.939,-83.939,-84.939,-83.939,-84.939,-83.939,-80.939,-75.939,-69.939,-66.939,-63.939003,-58.939003,-53.939003,-50.939003,-46.939003,-41.939003,-35.939003,-33.939003,-31.939003,-28.939003,-23.939003,-22.939003,-24.939003,-29.939003,-33.939003,-28.939003,-26.939003,-26.939003,-26.939003,-27.939003,-31.939003,-35.939003,-40.939003,-41.939003,-45.939003,-51.939003,-54.939003,-58.939003,-60.939003,-64.939,-68.939,-71.939,-74.939,-77.939,-81.939,-84.939,-81.939,-78.939,-76.939,-77.939,-76.939,-71.939,-66.939,-62.939003,-58.939003,-50.939003,-41.939003,-37.939003,-36.939003,-37.939003,-35.939003,-34.939003,-32.939003,-31.939003,-31.939003,-29.939003,-26.939003,-22.939003,-25.939003,-28.939003,-29.939003,-28.939003,-26.939003,-28.939003,-32.939003,-40.939003,-44.939003,-51.939003,-62.939003,-69.939,-76.939,-83.939,-88.939,-90.939,-91.939,-93.939,-95.939,-94.939,-93.939,-93.939,-92.939,-89.939,-88.939,-88.939,-89.939,-87.939,-84.939,-85.939,-86.939,-88.939,-86.939,-85.939,-86.939,-84.939,-82.939,-85.939,-87.939,-86.939,-86.939,-85.939,-86.939,-87.939,-87.939,-88.939,-87.939,-87.939,-89.939,-90.939,-87.939,-87.939,-87.939,-84.939,-85.939,-88.939,-89.939,-88.939,-91.939,-89.939,-87.939,-86.939,-86.939,-89.939,-88.939,-88.939,-91.939,-93.939,-93.939,-93.939,-93.939,-93.939,-95.939,-96.939,-95.939,-94.939,-93.939,-91.939,-89.939,-87.939,-87
.939,-87.939,-87.939,-89.939,-91.939,-93.939,-90.939,-84.939,-80.939,-78.939,-77.939,-72.939,-67.939,-62.939003,-58.939003,-54.939003,-59.939003,-62.939003,-62.939003,-59.939003,-55.939003,-59.939003,-58.939003,-54.939003,-60.939003,-62.939003,-54.939003,-57.939003,-63.939003,-58.939003,-55.939003,-50.939003,-49.939003,-48.939003,-46.939003,-38.939003,-28.939003,-27.939003,-29.939003,-33.939003,-43.939003,-50.939003,-49.939003,-48.939003,-45.939003,-47.939003,-49.939003,-53.939003,-50.939003,-48.939003,-49.939003,-51.939003,-52.939003,-55.939003,-57.939003,-58.939003,-76.939,-88.939,-86.939,-79.939,-73.939,-80.939,-83.939,-82.939,-94.939,-101.939,-95.939,-69.939,-40.939003,-44.939003,-51.939003,-59.939003,-63.939003,-67.939,-71.939,-70.939,-70.939,-73.939,-72.939,-67.939,-66.939,-65.939,-67.939,-64.939,-60.939003,-60.939003,-60.939003,-62.939003,-64.939,-66.939,-67.939,-67.939,-65.939,-49.939003,-47.939003,-61.939003,-49.939003,-35.939003,-23.939003,-21.939003,-20.939003,-3.939003,-9.939003,-40.939003,-60.939003,-65.939,-41.939003,-32.939003,-31.939003,-25.939003,-19.939003,-15.939003,-23.939003,-26.939003,-17.939003,-18.939003,-26.939003,-60.939003,-80.939,-85.939,-53.939003,-27.939003,-24.939003,-24.939003,-25.939003,-24.939003,-26.939003,-30.939003,-46.939003,-58.939003,82.061,54.060997,20.060997,31.060997,55.060997,91.061,87.061,79.061,76.061,73.061,70.061,69.061,66.061,59.060997,58.060997,56.060997,52.060997,50.060997,43.060997,-20.939003,-60.939003,-76.939,-62.939003,-49.939003,-45.939003,-40.939003,-34.939003,-26.939003,-18.939003,-9.939003,-7.939003,-8.939003,-9.939003,-2.939003,7.060997,6.060997,8.060997,11.060997,10.060997,8.060997,5.060997,6.060997,10.060997,5.060997,0.06099701,-4.939003,-7.939003,-6.939003,0.06099701,-12.939003,-31.939003,-37.939003,-43.939003,-49.939003,-57.939003,-66.939,-75.939,-83.939,-88.939,-92.939,-95.939,-97.939,-100.939,-102.939,-100.939,-99.939,-98.939,-97.939,-96.939,-95.939,-94.939,-92.939,-91.939,-90.939,-88.939,-85.939,-83.
939,-80.939,-75.939,-71.939,-68.939,-63.939003,-56.939003,-51.939003,-26.939003,19.060997,17.060997,10.060997,1.060997,-8.939003,-19.939003,-15.939003,-9.939003,-0.939003,0.06099701,2.060997,8.060997,11.060997,13.060997,13.060997,15.060997,18.060997,16.060997,12.060997,11.060997,5.060997,-1.939003,-3.939003,-8.939003,-14.939003,-27.939003,-37.939003,-41.939003,-51.939003,-62.939003,-66.939,-73.939,-80.939,-87.939,-92.939,-95.939,-98.939,-102.939,-102.939,-101.939,-99.939,-97.939,-95.939,-94.939,-94.939,-93.939,-91.939,-90.939,-89.939,-87.939,-86.939,-81.939,-76.939,-71.939,-65.939,-58.939003,-52.939003,-49.939003,-48.939003,-45.939003,-36.939003,-27.939003,-25.939003,-21.939003,-18.939003,-11.939003,-9.939003,-19.939003,-23.939003,-20.939003,-2.939003,12.060997,20.060997,18.060997,16.060997,18.060997,16.060997,12.060997,11.060997,6.060997,0.06099701,-1.939003,-5.939003,-18.939003,-27.939003,-33.939003,-40.939003,-48.939003,-53.939003,-67.939,-78.939,-79.939,-86.939,-93.939,-97.939,-99.939,-97.939,-95.939,-94.939,-94.939,-95.939,-95.939,-94.939,-93.939,-91.939,-89.939,-88.939,-86.939,-84.939,-82.939,-80.939,-75.939,-69.939,-64.939,-59.939003,-55.939003,-46.939003,-38.939003,-29.939003,-25.939003,-25.939003,-19.939003,-11.939003,-1.939003,1.060997,-0.939003,-8.939003,-8.939003,0.06099701,-15.939003,-28.939003,-24.939003,-23.939003,-22.939003,-19.939003,-14.939003,-9.939003,-3.939003,2.060997,8.060997,12.060997,15.060997,8.060997,5.060997,7.060997,15.060997,29.060997,56.060997,60.060997,56.060997,53.060997,55.060997,63.060997,67.061,70.061,77.061,79.061,79.061,77.061,78.061,79.061,70.061,66.061,76.061,50.060997,17.060997,38.060997,53.060997,59.060997,49.060997,40.060997,44.060997,41.060997,27.060997,-56.939003,-75.939,-27.939003,-66.939,-96.939,-83.939,-50.939003,-14.939003,-16.939003,-12.939003,-2.939003,-2.939003,-0.939003,2.060997,6.060997,10.060997,9.060997,9.060997,10.060997,10.060997,8.060997,4.060997,3.060997,2.060997,2.060997,-0.939003,-3.939003,-10.939003,-18.
939003,-26.939003,-32.939003,-39.939003,-50.939003,-59.939003,-65.939,-72.939,-79.939,-88.939,-91.939,-92.939,-95.939,-97.939,-96.939,-96.939,-95.939,-90.939,-51.939003,-8.939003,-54.939003,-83.939,-95.939,-91.939,-86.939,-81.939,-79.939,-78.939,-66.939,-30.939003,30.060997,19.060997,3.060997,-4.939003,-19.939003,-32.939003,-21.939003,-14.939003,-11.939003,-9.939003,-6.939003,2.060997,9.060997,14.060997,13.060997,14.060997,16.060997,17.060997,16.060997,14.060997,12.060997,9.060997,4.060997,-1.939003,-6.939003,-5.939003,-7.939003,-14.939003,-21.939003,-28.939003,-37.939003,-45.939003,-53.939003,-42.939003,-40.939003,-67.939,-81.939,-87.939,-76.939,-79.939,-96.939,-95.939,-94.939,-91.939,-90.939,-89.939,-88.939,-86.939,-84.939,-81.939,-80.939,-80.939,-28.939003,40.060997,39.060997,30.060997,13.060997,-22.939003,-47.939003,-41.939003,-36.939003,-32.939003,-26.939003,-22.939003,-18.939003,-11.939003,-3.939003,-1.939003,-0.939003,-0.939003,2.060997,4.060997,6.060997,8.060997,8.060997,4.060997,2.060997,2.060997,-1.939003,-7.939003,-14.939003,-11.939003,-10.939003,-17.939003,-24.939003,-32.939003,-36.939003,-42.939003,-49.939003,-62.939003,-73.939,-76.939,-82.939,-87.939,-94.939,-97.939,-96.939,-62.939003,-34.939003,-27.939003,-40.939003,-58.939003,-25.939003,7.060997,42.060997,32.060997,22.060997,22.060997,16.060997,7.060997,1.060997,-2.939003,-8.939003,-12.939003,-17.939003,-21.939003,-20.939003,-17.939003,-23.939003,-31.939003,-41.939003,-35.939003,-31.939003,-34.939003,-33.939003,-31.939003,-30.939003,-27.939003,-23.939003,-22.939003,-19.939003,-13.939003,-29.939003,-52.939003,-50.939003,-48.939003,-48.939003,-29.939003,-14.939003,-13.939003,-5.939003,3.060997,-12.939003,-16.939003,-9.939003,-4.939003,3.060997,17.060997,4.060997,-15.939003,13.060997,26.060997,23.060997,18.060997,15.060997,15.060997,-1.939003,-20.939003,-9.939003,6.060997,27.060997,18.060997,10.060997,8.060997,11.060997,14.060997,2.060997,14.060997,50.060997,36.060997,12.060997,-22.939003,-7.939003,22.0
60997,19.060997,13.060997,5.060997,15.060997,15.060997,-8.939003,-14.939003,-11.939003,-3.939003,-15.939003,-45.939003,-61.939003,-76.939,-94.939,-57.939003,-4.939003,-25.939003,-53.939003,-87.939,-87.939,-86.939,-92.939,-92.939,-91.939,-91.939,-90.939,-88.939,-87.939,-86.939,-84.939,-83.939,-82.939,-77.939,-72.939,-65.939,-59.939003,-53.939003,-49.939003,-42.939003,-34.939003,-32.939003,-28.939003,-22.939003,-17.939003,-14.939003,-13.939003,-8.939003,-3.939003,-5.939003,-8.939003,-11.939003,-11.939003,-12.939003,-20.939003,-19.939003,-12.939003,-1.939003,-6.939003,-28.939003,-45.939003,-58.939003,-62.939003,-69.939,-76.939,-76.939,-80.939,-88.939,-90.939,-92.939,-91.939,-89.939,-87.939,-86.939,-85.939,-84.939,-82.939,-78.939,-69.939,-64.939,-61.939003,-53.939003,-47.939003,-43.939003,-35.939003,-29.939003,-24.939003,-26.939003,-30.939003,-30.939003,-29.939003,-26.939003,-24.939003,-22.939003,-21.939003,-24.939003,-28.939003,-25.939003,-22.939003,-19.939003,-20.939003,-20.939003,-19.939003,-18.939003,-18.939003,-27.939003,-42.939003,-62.939003,-68.939,-75.939,-90.939,-95.939,-99.939,-99.939,-98.939,-96.939,-94.939,-93.939,-94.939,-92.939,-90.939,-88.939,-86.939,-83.939,-85.939,-86.939,-87.939,-84.939,-79.939,-82.939,-84.939,-85.939,-86.939,-86.939,-87.939,-84.939,-80.939,-83.939,-85.939,-85.939,-86.939,-86.939,-87.939,-88.939,-87.939,-88.939,-88.939,-87.939,-91.939,-93.939,-90.939,-87.939,-85.939,-84.939,-86.939,-89.939,-88.939,-87.939,-87.939,-86.939,-85.939,-83.939,-83.939,-86.939,-84.939,-84.939,-85.939,-88.939,-90.939,-87.939,-86.939,-89.939,-90.939,-90.939,-88.939,-88.939,-87.939,-89.939,-89.939,-87.939,-85.939,-84.939,-87.939,-91.939,-94.939,-96.939,-96.939,-95.939,-96.939,-99.939,-100.939,-98.939,-94.939,-91.939,-89.939,-88.939,-92.939,-95.939,-92.939,-91.939,-90.939,-93.939,-89.939,-81.939,-86.939,-88.939,-83.939,-82.939,-83.939,-83.939,-79.939,-71.939,-70.939,-68.939,-63.939003,-36.939003,-2.939003,2.060997,-6.939003,-29.939003,-47.939003,-59.939003,-56.939
003,-53.939003,-49.939003,-52.939003,-53.939003,-53.939003,-51.939003,-49.939003,-45.939003,-44.939003,-44.939003,-46.939003,-45.939003,-41.939003,-68.939,-86.939,-81.939,-72.939,-64.939,-75.939,-81.939,-83.939,-94.939,-101.939,-91.939,-64.939,-35.939003,-44.939003,-54.939003,-63.939003,-66.939,-66.939,-67.939,-64.939,-61.939003,-68.939,-70.939,-68.939,-68.939,-68.939,-70.939,-67.939,-63.939003,-60.939003,-60.939003,-64.939,-65.939,-66.939,-68.939,-67.939,-64.939,-55.939003,-51.939003,-54.939003,-34.939003,-21.939003,-29.939003,-27.939003,-19.939003,-8.939003,-15.939003,-42.939003,-66.939,-77.939,-58.939003,-43.939003,-32.939003,-18.939003,-12.939003,-14.939003,-28.939003,-32.939003,-12.939003,-9.939003,-17.939003,-53.939003,-73.939,-77.939,-44.939003,-17.939003,-15.939003,-21.939003,-30.939003,-24.939003,-25.939003,-34.939003,-51.939003,-63.939003,75.061,52.060997,21.060997,16.060997,30.060997,66.061,57.060997,44.060997,35.060997,28.060997,21.060997,14.060997,7.060997,1.060997,-6.939003,-12.939003,-16.939003,-19.939003,-21.939003,-18.939003,-10.939003,1.060997,7.060997,11.060997,12.060997,14.060997,17.060997,14.060997,10.060997,5.060997,6.060997,6.060997,1.060997,-6.939003,-15.939003,-18.939003,-21.939003,-24.939003,-33.939003,-40.939003,-44.939003,-49.939003,-54.939003,-63.939003,-71.939,-79.939,-70.939,-51.939003,-14.939003,-37.939003,-79.939,-85.939,-89.939,-91.939,-92.939,-94.939,-97.939,-98.939,-100.939,-100.939,-100.939,-99.939,-98.939,-97.939,-92.939,-87.939,-82.939,-78.939,-75.939,-70.939,-63.939003,-56.939003,-51.939003,-45.939003,-39.939003,-32.939003,-24.939003,-15.939003,-9.939003,-3.939003,3.060997,9.060997,15.060997,15.060997,5.060997,-15.939003,-21.939003,-22.939003,-16.939003,-12.939003,-8.939003,2.060997,5.060997,-1.939003,-9.939003,-15.939003,-16.939003,-20.939003,-25.939003,-32.939003,-41.939003,-51.939003,-58.939003,-65.939,-66.939,-72.939,-79.939,-80.939,-81.939,-83.939,-85.939,-88.939,-89.939,-91.939,-93.939,-95.939,-96.939,-98.939,-99.939,-10
0.939,-101.939,-101.939,-101.939,-100.939,-95.939,-87.939,-77.939,-69.939,-65.939,-62.939003,-59.939003,-52.939003,-46.939003,-40.939003,-35.939003,-27.939003,-15.939003,-8.939003,-3.939003,3.060997,10.060997,19.060997,18.060997,18.060997,19.060997,21.060997,24.060997,26.060997,23.060997,16.060997,2.060997,-5.939003,-1.939003,8.060997,15.060997,-0.939003,-15.939003,-28.939003,-35.939003,-41.939003,-49.939003,-57.939003,-64.939,-70.939,-75.939,-79.939,-79.939,-80.939,-83.939,-85.939,-87.939,-89.939,-90.939,-92.939,-95.939,-97.939,-98.939,-98.939,-99.939,-97.939,-94.939,-91.939,-87.939,-82.939,-76.939,-70.939,-67.939,-65.939,-59.939003,-50.939003,-44.939003,-37.939003,-31.939003,-22.939003,-14.939003,-10.939003,-6.939003,-0.939003,6.060997,13.060997,15.060997,18.060997,20.060997,22.060997,22.060997,19.060997,18.060997,17.060997,10.060997,3.060997,-2.939003,-8.939003,-9.939003,-6.939003,2.060997,12.060997,23.060997,28.060997,30.060997,35.060997,41.060997,49.060997,55.060997,60.060997,67.061,72.061,71.061,32.060997,18.060997,30.060997,33.060997,47.060997,91.061,95.061,81.061,72.061,71.061,77.061,72.061,68.061,66.061,60.060997,54.060997,48.060997,43.060997,37.060997,27.060997,21.060997,21.060997,9.060997,-7.939003,-3.939003,-3.939003,-5.939003,-14.939003,-20.939003,-20.939003,-19.939003,-23.939003,-71.939,-39.939003,74.061,-13.939003,-87.939,-76.939,-30.939003,23.060997,11.060997,3.060997,2.060997,-6.939003,-12.939003,-15.939003,-20.939003,-26.939003,-34.939003,-41.939003,-49.939003,-57.939003,-64.939,-70.939,-75.939,-78.939,-79.939,-79.939,-80.939,-81.939,-83.939,-85.939,-87.939,-88.939,-91.939,-93.939,-94.939,-96.939,-98.939,-100.939,-96.939,-90.939,-87.939,-82.939,-75.939,-73.939,-77.939,-91.939,-18.939003,72.061,-6.939003,-63.939003,-98.939,-64.939,-28.939003,-7.939003,-4.939003,-6.939003,-0.939003,-3.939003,-17.939003,-25.939003,-28.939003,-23.939003,-4.939003,15.060997,15.060997,15.060997,17.060997,12.060997,7.060997,3.060997,-3.939003,-10.939003,-18.939003,-24.939
003,-30.939003,-37.939003,-44.939003,-53.939003,-61.939003,-69.939,-76.939,-79.939,-81.939,-80.939,-81.939,-82.939,-84.939,-86.939,-88.939,-90.939,-91.939,-69.939,-59.939003,-82.939,-91.939,-92.939,-76.939,-71.939,-77.939,-70.939,-63.939003,-59.939003,-51.939003,-42.939003,-40.939003,-34.939003,-22.939003,-15.939003,-9.939003,-9.939003,-12.939003,-17.939003,-24.939003,-28.939003,-28.939003,-7.939003,10.060997,9.060997,11.060997,14.060997,13.060997,8.060997,-0.939003,3.060997,4.060997,-1.939003,-10.939003,-21.939003,-26.939003,-31.939003,-35.939003,-42.939003,-50.939003,-62.939003,-68.939,-72.939,-77.939,-81.939,-83.939,-82.939,-82.939,-83.939,-85.939,-87.939,-88.939,-89.939,-90.939,-91.939,-91.939,-91.939,-88.939,-83.939,-79.939,-75.939,-71.939,-47.939003,-26.939003,-24.939003,-25.939003,-29.939003,-23.939003,-18.939003,-14.939003,-22.939003,-28.939003,-28.939003,-31.939003,-35.939003,-37.939003,-37.939003,-36.939003,-36.939003,-33.939003,-26.939003,-25.939003,-26.939003,-23.939003,-19.939003,-16.939003,-14.939003,-9.939003,-1.939003,4.060997,9.060997,13.060997,17.060997,21.060997,29.060997,38.060997,48.060997,5.060997,-54.939003,-48.939003,-43.939003,-40.939003,-20.939003,-5.939003,-6.939003,0.06099701,9.060997,-3.939003,-14.939003,-22.939003,-25.939003,-15.939003,22.060997,5.060997,-30.939003,-35.939003,-32.939003,-21.939003,-20.939003,-23.939003,-32.939003,-39.939003,-42.939003,-15.939003,0.06099701,6.060997,6.060997,8.060997,16.060997,24.060997,30.060997,16.060997,13.060997,21.060997,4.060997,-9.939003,-10.939003,-2.939003,7.060997,14.060997,15.060997,9.060997,17.060997,16.060997,-13.939003,-20.939003,-19.939003,-15.939003,-23.939003,-41.939003,-64.939,-85.939,-98.939,-52.939003,7.060997,-37.939003,-71.939,-93.939,-83.939,-74.939,-74.939,-66.939,-55.939003,-51.939003,-45.939003,-38.939003,-35.939003,-29.939003,-19.939003,-15.939003,-13.939003,-12.939003,-11.939003,-10.939003,-7.939003,-7.939003,-7.939003,-6.939003,-6.939003,-13.939003,-17.939003,-17.939003,-22.9
39003,-30.939003,-41.939003,-43.939003,-42.939003,-50.939003,-58.939003,-66.939,-72.939,-73.939,-63.939003,-15.939003,42.060997,40.060997,11.060997,-44.939003,-73.939,-89.939,-84.939,-78.939,-74.939,-72.939,-69.939,-62.939003,-58.939003,-53.939003,-49.939003,-43.939003,-36.939003,-31.939003,-27.939003,-21.939003,-21.939003,-20.939003,-18.939003,-16.939003,-16.939003,-12.939003,-11.939003,-12.939003,-14.939003,-17.939003,-22.939003,-25.939003,-26.939003,-32.939003,-31.939003,-24.939003,-30.939003,-37.939003,-46.939003,-34.939003,-16.939003,-20.939003,-31.939003,-48.939003,-51.939003,-52.939003,-51.939003,-53.939003,-60.939003,-76.939,-85.939,-87.939,-89.939,-91.939,-91.939,-90.939,-90.939,-87.939,-87.939,-89.939,-89.939,-88.939,-88.939,-87.939,-87.939,-90.939,-90.939,-88.939,-87.939,-86.939,-86.939,-84.939,-82.939,-83.939,-84.939,-87.939,-87.939,-87.939,-88.939,-85.939,-82.939,-84.939,-85.939,-84.939,-83.939,-84.939,-86.939,-87.939,-89.939,-88.939,-89.939,-92.939,-92.939,-92.939,-90.939,-89.939,-88.939,-88.939,-87.939,-86.939,-87.939,-87.939,-86.939,-86.939,-87.939,-84.939,-83.939,-85.939,-85.939,-85.939,-90.939,-90.939,-89.939,-86.939,-87.939,-91.939,-90.939,-90.939,-90.939,-89.939,-90.939,-88.939,-87.939,-88.939,-89.939,-89.939,-89.939,-90.939,-92.939,-91.939,-91.939,-93.939,-95.939,-95.939,-95.939,-92.939,-89.939,-92.939,-94.939,-97.939,-94.939,-92.939,-91.939,-93.939,-98.939,-94.939,-94.939,-97.939,-97.939,-97.939,-94.939,-93.939,-94.939,-95.939,-95.939,-94.939,-93.939,-92.939,-88.939,-37.939003,25.060997,45.060997,23.060997,-40.939003,-72.939,-93.939,-92.939,-91.939,-91.939,-91.939,-91.939,-91.939,-91.939,-90.939,-90.939,-89.939,-89.939,-90.939,-90.939,-89.939,-95.939,-95.939,-79.939,-72.939,-69.939,-71.939,-76.939,-85.939,-96.939,-102.939,-100.939,-78.939,-50.939003,-53.939003,-60.939003,-70.939,-67.939,-65.939,-65.939,-60.939003,-56.939003,-62.939003,-69.939,-77.939,-72.939,-68.939,-72.939,-69.939,-65.939,-65.939,-67.939,-69.939,-68.939,-67.939,-69.939,-66.939
,-60.939003,-62.939003,-56.939003,-46.939003,-18.939003,-2.939003,-21.939003,-23.939003,-18.939003,-6.939003,-17.939003,-51.939003,-71.939,-84.939,-83.939,-68.939,-44.939003,-20.939003,-15.939003,-29.939003,-29.939003,-28.939003,-24.939003,-27.939003,-32.939003,-43.939003,-48.939003,-47.939003,-27.939003,-14.939003,-22.939003,-26.939003,-28.939003,-22.939003,-27.939003,-43.939003,-51.939003,-56.939003,33.060997,21.060997,6.060997,-0.939003,3.060997,18.060997,10.060997,0.06099701,-5.939003,-9.939003,-13.939003,-16.939003,-19.939003,-21.939003,-24.939003,-25.939003,-25.939003,-24.939003,-21.939003,-12.939003,-6.939003,-3.939003,-3.939003,-4.939003,-6.939003,-8.939003,-11.939003,-18.939003,-25.939003,-30.939003,-33.939003,-37.939003,-41.939003,-49.939003,-57.939003,-59.939003,-61.939003,-63.939003,-70.939,-75.939,-76.939,-80.939,-87.939,-92.939,-97.939,-102.939,-90.939,-65.939,-20.939003,-36.939003,-75.939,-87.939,-91.939,-85.939,-84.939,-82.939,-79.939,-76.939,-74.939,-67.939,-63.939003,-59.939003,-56.939003,-53.939003,-48.939003,-43.939003,-36.939003,-30.939003,-26.939003,-23.939003,-18.939003,-12.939003,-8.939003,-5.939003,-2.939003,0.06099701,3.060997,7.060997,6.060997,7.060997,10.060997,10.060997,8.060997,1.060997,-1.939003,-0.939003,2.060997,7.060997,16.060997,12.060997,1.060997,-23.939003,-39.939003,-46.939003,-52.939003,-56.939003,-58.939003,-61.939003,-65.939,-71.939,-78.939,-86.939,-91.939,-95.939,-96.939,-99.939,-103.939,-102.939,-100.939,-98.939,-95.939,-92.939,-92.939,-89.939,-85.939,-82.939,-78.939,-72.939,-69.939,-66.939,-64.939,-59.939003,-55.939003,-51.939003,-45.939003,-37.939003,-27.939003,-19.939003,-14.939003,-13.939003,-11.939003,-6.939003,-3.939003,0.06099701,1.060997,3.060997,7.060997,9.060997,9.060997,8.060997,8.060997,9.060997,5.060997,0.06099701,-4.939003,-9.939003,-14.939003,-14.939003,-19.939003,-27.939003,6.060997,37.060997,42.060997,58.060997,69.061,17.060997,-27.939003,-67.939,-74.939,-78.939,-84.939,-89.939,-94.939,-97.939,-100.939,-100
.939,-99.939,-97.939,-96.939,-92.939,-87.939,-85.939,-82.939,-80.939,-75.939,-71.939,-69.939,-64.939,-59.939003,-53.939003,-48.939003,-43.939003,-39.939003,-33.939003,-24.939003,-18.939003,-14.939003,-14.939003,-10.939003,-4.939003,-1.939003,2.060997,2.060997,4.060997,6.060997,7.060997,6.060997,4.060997,4.060997,4.060997,-1.939003,-5.939003,-10.939003,-15.939003,-18.939003,-23.939003,-25.939003,-28.939003,-37.939003,-35.939003,-29.939003,-20.939003,-15.939003,-15.939003,17.060997,48.060997,64.061,67.061,64.061,68.061,69.061,67.061,69.061,70.061,73.061,75.061,71.061,25.060997,9.060997,23.060997,23.060997,32.060997,59.060997,59.060997,45.060997,37.060997,34.060997,34.060997,29.060997,24.060997,21.060997,17.060997,10.060997,4.060997,0.06099701,-2.939003,-5.939003,-6.939003,-7.939003,-9.939003,-11.939003,-9.939003,-11.939003,-16.939003,-19.939003,-18.939003,-3.939003,-10.939003,-31.939003,-76.939,-52.939003,38.060997,-26.939003,-83.939,-83.939,-57.939003,-25.939003,-34.939003,-40.939003,-43.939003,-49.939003,-53.939003,-56.939003,-61.939003,-66.939,-71.939,-77.939,-83.939,-88.939,-93.939,-97.939,-98.939,-97.939,-96.939,-93.939,-88.939,-85.939,-83.939,-81.939,-78.939,-74.939,-69.939,-66.939,-64.939,-63.939003,-60.939003,-56.939003,-50.939003,-44.939003,-37.939003,-29.939003,-19.939003,-39.939003,-64.939,-91.939,-42.939003,26.060997,-23.939003,-64.939,-98.939,-63.939003,-25.939003,7.060997,10.060997,1.060997,-1.939003,-1.939003,0.06099701,-1.939003,-1.939003,7.060997,-3.939003,-21.939003,-28.939003,-31.939003,-30.939003,-34.939003,-38.939003,-43.939003,-49.939003,-55.939003,-60.939003,-65.939,-70.939,-75.939,-80.939,-86.939,-90.939,-95.939,-98.939,-96.939,-91.939,-89.939,-86.939,-82.939,-78.939,-75.939,-70.939,-66.939,-64.939,-49.939003,-40.939003,-48.939003,-51.939003,-50.939003,-39.939003,-31.939003,-26.939003,-20.939003,-15.939003,-13.939003,-7.939003,-1.939003,-1.939003,1.060997,6.060997,5.060997,4.060997,4.060997,-3.939003,-15.939003,-17.939003,-14.939003,-6.939003,-
10.939003,-16.939003,-23.939003,-25.939003,-27.939003,-30.939003,-35.939003,-41.939003,-41.939003,-41.939003,-46.939003,-52.939003,-61.939003,-65.939,-69.939,-71.939,-77.939,-82.939,-88.939,-90.939,-91.939,-92.939,-91.939,-87.939,-82.939,-77.939,-73.939,-71.939,-69.939,-65.939,-61.939003,-59.939003,-56.939003,-51.939003,-45.939003,-41.939003,-36.939003,-31.939003,-27.939003,-23.939003,-20.939003,-17.939003,-18.939003,-16.939003,-13.939003,-16.939003,-21.939003,-25.939003,-26.939003,-27.939003,-28.939003,-27.939003,-25.939003,-22.939003,-18.939003,-13.939003,-8.939003,-4.939003,-7.939003,-6.939003,-2.939003,9.060997,23.060997,36.060997,31.060997,28.060997,35.060997,37.060997,37.060997,36.060997,36.060997,36.060997,35.060997,35.060997,38.060997,-5.939003,-62.939003,-57.939003,-53.939003,-51.939003,-35.939003,-24.939003,-29.939003,-28.939003,-25.939003,-34.939003,-43.939003,-52.939003,-54.939003,-48.939003,-26.939003,-35.939003,-57.939003,-64.939,-65.939,-59.939003,-57.939003,-58.939003,-67.939,-68.939,-65.939,-51.939003,-43.939003,-43.939003,-41.939003,-38.939003,-30.939003,-23.939003,-20.939003,-22.939003,-24.939003,-27.939003,-32.939003,-24.939003,10.060997,20.060997,15.060997,-8.939003,-16.939003,-7.939003,5.060997,6.060997,-22.939003,-28.939003,-25.939003,-25.939003,-30.939003,-39.939003,-64.939,-84.939,-91.939,-57.939003,-13.939003,-33.939003,-45.939003,-47.939003,-40.939003,-34.939003,-32.939003,-27.939003,-22.939003,-20.939003,-18.939003,-17.939003,-19.939003,-19.939003,-16.939003,-18.939003,-21.939003,-21.939003,-25.939003,-34.939003,-38.939003,-41.939003,-42.939003,-42.939003,-43.939003,-49.939003,-51.939003,-52.939003,-56.939003,-61.939003,-68.939,-68.939,-65.939,-67.939,-71.939,-75.939,-78.939,-75.939,-62.939003,-24.939003,16.060997,12.060997,-4.939003,-33.939003,-48.939003,-55.939003,-50.939003,-44.939003,-39.939003,-38.939003,-35.939003,-31.939003,-30.939003,-28.939003,-28.939003,-25.939003,-20.939003,-18.939003,-18.939003,-18.939003,-22.939003,-25.939003
,-26.939003,-28.939003,-31.939003,-35.939003,-38.939003,-41.939003,-45.939003,-51.939003,-57.939003,-56.939003,-51.939003,-49.939003,-47.939003,-45.939003,-52.939003,-60.939003,-70.939,-61.939003,-48.939003,-53.939003,-61.939003,-74.939,-76.939,-77.939,-78.939,-79.939,-82.939,-91.939,-94.939,-91.939,-91.939,-92.939,-91.939,-89.939,-87.939,-84.939,-85.939,-87.939,-87.939,-87.939,-88.939,-87.939,-86.939,-89.939,-90.939,-88.939,-88.939,-86.939,-85.939,-84.939,-83.939,-83.939,-84.939,-86.939,-88.939,-88.939,-90.939,-88.939,-86.939,-87.939,-86.939,-83.939,-83.939,-83.939,-85.939,-86.939,-86.939,-86.939,-89.939,-94.939,-91.939,-88.939,-89.939,-88.939,-87.939,-88.939,-88.939,-86.939,-86.939,-87.939,-87.939,-87.939,-88.939,-86.939,-85.939,-85.939,-86.939,-88.939,-90.939,-91.939,-90.939,-88.939,-87.939,-89.939,-90.939,-91.939,-91.939,-90.939,-90.939,-87.939,-86.939,-87.939,-88.939,-89.939,-89.939,-90.939,-90.939,-88.939,-89.939,-93.939,-94.939,-94.939,-92.939,-90.939,-88.939,-90.939,-92.939,-94.939,-92.939,-91.939,-90.939,-93.939,-96.939,-92.939,-92.939,-97.939,-96.939,-94.939,-90.939,-85.939,-79.939,-74.939,-68.939,-64.939,-64.939,-65.939,-66.939,-40.939003,-6.939003,6.060997,-4.939003,-40.939003,-60.939003,-73.939,-74.939,-75.939,-76.939,-77.939,-77.939,-78.939,-79.939,-81.939,-86.939,-86.939,-85.939,-88.939,-91.939,-94.939,-95.939,-93.939,-81.939,-75.939,-71.939,-70.939,-73.939,-83.939,-94.939,-102.939,-103.939,-91.939,-75.939,-69.939,-69.939,-75.939,-75.939,-74.939,-71.939,-69.939,-68.939,-67.939,-71.939,-77.939,-75.939,-71.939,-66.939,-63.939003,-62.939003,-61.939003,-62.939003,-63.939003,-61.939003,-60.939003,-63.939003,-58.939003,-51.939003,-58.939003,-61.939003,-61.939003,-34.939003,-16.939003,-26.939003,-29.939003,-28.939003,-19.939003,-31.939003,-66.939,-79.939,-83.939,-73.939,-67.939,-59.939003,-40.939003,-30.939003,-29.939003,-21.939003,-16.939003,-19.939003,-19.939003,-19.939003,-24.939003,-30.939003,-37.939003,-24.939003,-16.939003,-27.939003,-34.939003,-37.939
003,-35.939003,-38.939003,-45.939003,-46.939003,-46.939003,-17.939003,-15.939003,-13.939003,-15.939003,-20.939003,-28.939003,-35.939003,-37.939003,-34.939003,-33.939003,-33.939003,-30.939003,-26.939003,-22.939003,-16.939003,-9.939003,-5.939003,2.060997,8.060997,-2.939003,-20.939003,-44.939003,-49.939003,-50.939003,-55.939003,-63.939003,-71.939,-80.939,-84.939,-83.939,-90.939,-97.939,-100.939,-101.939,-100.939,-100.939,-100.939,-98.939,-99.939,-99.939,-94.939,-95.939,-99.939,-97.939,-97.939,-97.939,-86.939,-63.939003,-21.939003,-24.939003,-47.939003,-65.939,-69.939,-56.939003,-54.939003,-50.939003,-43.939003,-36.939003,-32.939003,-18.939003,-10.939003,-6.939003,-1.939003,2.060997,3.060997,6.060997,12.060997,19.060997,22.060997,21.060997,22.060997,22.060997,23.060997,21.060997,18.060997,13.060997,9.060997,4.060997,-5.939003,-12.939003,-14.939003,-21.939003,-31.939003,-45.939003,-23.939003,34.060997,51.060997,61.060997,66.061,48.060997,13.060997,-64.939,-103.939,-101.939,-102.939,-101.939,-101.939,-101.939,-101.939,-100.939,-99.939,-99.939,-98.939,-98.939,-98.939,-97.939,-96.939,-94.939,-90.939,-85.939,-78.939,-72.939,-71.939,-64.939,-55.939003,-50.939003,-40.939003,-29.939003,-22.939003,-15.939003,-11.939003,-3.939003,4.060997,9.060997,14.060997,18.060997,24.060997,28.060997,32.060997,32.060997,30.060997,29.060997,27.060997,26.060997,20.060997,12.060997,3.060997,-2.939003,-7.939003,-16.939003,-25.939003,-34.939003,-42.939003,-51.939003,-63.939003,-75.939,-85.939,-88.939,-92.939,-96.939,6.060997,88.061,90.061,108.061,122.061,39.060997,-32.939003,-94.939,-99.939,-98.939,-96.939,-96.939,-96.939,-96.939,-93.939,-90.939,-86.939,-82.939,-81.939,-71.939,-60.939003,-55.939003,-50.939003,-46.939003,-35.939003,-27.939003,-23.939003,-14.939003,-4.939003,3.060997,8.060997,12.060997,16.060997,21.060997,29.060997,33.060997,35.060997,32.060997,30.060997,30.060997,28.060997,25.060997,16.060997,8.060997,-1.939003,-3.939003,-9.939003,-20.939003,-30.939003,-40.939003,-53.939003,-64.939,
-73.939,-85.939,-92.939,-94.939,-96.939,-97.939,-101.939,-85.939,-61.939003,-34.939003,-21.939003,-24.939003,27.060997,73.061,91.061,90.061,79.061,82.061,76.061,62.060997,58.060997,56.060997,52.060997,49.060997,43.060997,6.060997,-6.939003,3.060997,3.060997,3.060997,4.060997,-1.939003,-11.939003,-14.939003,-18.939003,-23.939003,-26.939003,-28.939003,-27.939003,-28.939003,-31.939003,-35.939003,-35.939003,-31.939003,-25.939003,-19.939003,-16.939003,-11.939003,-5.939003,1.060997,5.060997,2.060997,4.060997,15.060997,44.060997,21.060997,-25.939003,-76.939,-84.939,-47.939003,-64.939,-82.939,-93.939,-99.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-98.939,-98.939,-98.939,-97.939,-97.939,-96.939,-95.939,-90.939,-83.939,-81.939,-74.939,-63.939003,-57.939003,-52.939003,-47.939003,-40.939003,-32.939003,-22.939003,-16.939003,-12.939003,-9.939003,-4.939003,3.060997,8.060997,12.060997,20.060997,29.060997,38.060997,-8.939003,-56.939003,-92.939,-85.939,-63.939003,-63.939003,-74.939,-95.939,-74.939,-45.939003,-6.939003,-5.939003,-20.939003,-33.939003,-11.939003,45.060997,47.060997,46.060997,54.060997,-6.939003,-88.939,-98.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-98.939,-98.939,-97.939,-97.939,-97.939,-97.939,-96.939,-93.939,-90.939,-87.939,-81.939,-69.939,-64.939,-59.939003,-51.939003,-42.939003,-34.939003,-24.939003,-17.939003,-11.939003,-12.939003,-10.939003,-0.939003,3.060997,4.060997,4.060997,12.060997,26.060997,28.060997,27.060997,26.060997,24.060997,24.060997,22.060997,18.060997,11.060997,-0.939003,-9.939003,-11.939003,-1.939003,12.060997,19.060997,28.060997,39.060997,-21.939003,-74.939,-87.939,-93.939,-97.939,-99.939,-100.939,-99.939,-100.939,-99.939,-99.939,-99.939,-98.939,-97.939,-97.939,-96.939,-96.939,-94.939,-89.939,-84.939,-80.939,-75.939,-68.939,-60.939003,-49.939003,-40.939003,-32.939003,-27.939003,-23.939003,-14.939003,-7.939003,-3.939003,-0.939003,5.060997,15.060997,18.060997,16.060997,19.060997,21.060997,21.060997,4.06
0997,-9.939003,-12.939003,-12.939003,-8.939003,-10.939003,-12.939003,-12.939003,-6.939003,-2.939003,-3.939003,0.06099701,7.060997,15.060997,21.060997,27.060997,37.060997,36.060997,16.060997,16.060997,26.060997,46.060997,67.061,87.061,72.061,59.060997,59.060997,54.060997,47.060997,40.060997,34.060997,29.060997,16.060997,3.060997,-5.939003,-36.939003,-73.939,-70.939,-68.939,-69.939,-60.939003,-55.939003,-64.939,-69.939,-73.939,-76.939,-80.939,-85.939,-84.939,-85.939,-91.939,-89.939,-84.939,-84.939,-85.939,-88.939,-88.939,-89.939,-93.939,-91.939,-87.939,-94.939,-97.939,-98.939,-97.939,-96.939,-91.939,-90.939,-90.939,-78.939,-74.939,-80.939,-68.939,-38.939003,26.060997,40.060997,25.060997,-38.939003,-57.939003,-30.939003,-10.939003,-4.939003,-32.939003,-36.939003,-31.939003,-34.939003,-37.939003,-39.939003,-62.939003,-80.939,-80.939,-63.939003,-40.939003,-25.939003,-9.939003,8.060997,8.060997,6.060997,7.060997,4.060997,-0.939003,-3.939003,-7.939003,-15.939003,-23.939003,-33.939003,-40.939003,-50.939003,-60.939003,-60.939003,-69.939,-86.939,-94.939,-99.939,-99.939,-99.939,-99.939,-99.939,-97.939,-95.939,-94.939,-92.939,-89.939,-85.939,-79.939,-73.939,-68.939,-65.939,-62.939003,-57.939003,-46.939003,-41.939003,-39.939003,-42.939003,-36.939003,-18.939003,-10.939003,-4.939003,-3.939003,-1.939003,-0.939003,-0.939003,-3.939003,-8.939003,-11.939003,-16.939003,-22.939003,-24.939003,-25.939003,-28.939003,-34.939003,-43.939003,-51.939003,-58.939003,-61.939003,-64.939,-70.939,-81.939,-87.939,-89.939,-94.939,-99.939,-101.939,-96.939,-85.939,-73.939,-70.939,-74.939,-79.939,-84.939,-90.939,-92.939,-93.939,-96.939,-97.939,-96.939,-96.939,-97.939,-100.939,-97.939,-93.939,-91.939,-89.939,-85.939,-85.939,-86.939,-89.939,-88.939,-86.939,-85.939,-86.939,-88.939,-86.939,-86.939,-90.939,-89.939,-86.939,-88.939,-88.939,-87.939,-88.939,-87.939,-85.939,-83.939,-82.939,-83.939,-84.939,-85.939,-87.939,-89.939,-91.939,-90.939,-89.939,-90.939,-88.939,-84.939,-84.939,-84.939,-84.939,-84.939,-83.939,
-84.939,-88.939,-95.939,-90.939,-85.939,-87.939,-86.939,-84.939,-87.939,-88.939,-87.939,-87.939,-87.939,-89.939,-88.939,-87.939,-88.939,-88.939,-85.939,-88.939,-91.939,-89.939,-90.939,-91.939,-90.939,-88.939,-87.939,-90.939,-92.939,-93.939,-91.939,-89.939,-87.939,-86.939,-85.939,-86.939,-88.939,-90.939,-90.939,-89.939,-86.939,-87.939,-92.939,-93.939,-93.939,-91.939,-89.939,-87.939,-87.939,-87.939,-88.939,-90.939,-92.939,-91.939,-91.939,-90.939,-88.939,-88.939,-90.939,-90.939,-89.939,-80.939,-71.939,-60.939003,-46.939003,-32.939003,-19.939003,-20.939003,-24.939003,-30.939003,-40.939003,-53.939003,-55.939003,-49.939003,-36.939003,-36.939003,-36.939003,-38.939003,-40.939003,-43.939003,-44.939003,-44.939003,-46.939003,-47.939003,-51.939003,-62.939003,-62.939003,-59.939003,-66.939,-73.939,-79.939,-83.939,-86.939,-85.939,-79.939,-72.939,-70.939,-73.939,-81.939,-92.939,-101.939,-101.939,-100.939,-97.939,-85.939,-79.939,-79.939,-85.939,-87.939,-78.939,-80.939,-85.939,-77.939,-74.939,-76.939,-77.939,-73.939,-58.939003,-57.939003,-59.939003,-54.939003,-54.939003,-55.939003,-53.939003,-53.939003,-57.939003,-52.939003,-43.939003,-53.939003,-67.939,-84.939,-63.939003,-42.939003,-37.939003,-38.939003,-43.939003,-35.939003,-47.939003,-79.939,-84.939,-78.939,-51.939003,-56.939003,-72.939,-62.939003,-47.939003,-23.939003,-10.939003,-2.939003,-8.939003,-4.939003,0.06099701,-6.939003,-18.939003,-35.939003,-26.939003,-21.939003,-31.939003,-40.939003,-48.939003,-49.939003,-48.939003,-45.939003,-39.939003,-35.939003,-5.939003,-4.939003,-4.939003,-6.939003,-3.939003,8.060997,9.060997,11.060997,20.060997,23.060997,25.060997,29.060997,33.060997,37.060997,42.060997,47.060997,46.060997,57.060997,67.061,25.060997,-23.939003,-79.939,-85.939,-85.939,-87.939,-90.939,-92.939,-95.939,-95.939,-92.939,-91.939,-90.939,-92.939,-90.939,-85.939,-82.939,-78.939,-74.939,-70.939,-66.939,-60.939003,-57.939003,-55.939003,-48.939003,-43.939003,-39.939003,-32.939003,-22.939003,-7.939003,-3.939003,-4.939003,-8.9
39003,-5.939003,3.060997,6.060997,9.060997,11.060997,11.060997,9.060997,11.060997,12.060997,11.060997,7.060997,1.060997,-10.939003,-15.939003,-16.939003,-25.939003,-31.939003,-35.939003,-38.939003,-44.939003,-52.939003,-57.939003,-62.939003,-64.939,-65.939,-67.939,-70.939,-73.939,-73.939,-76.939,-79.939,-84.939,-41.939003,49.060997,72.061,82.061,79.061,69.061,46.060997,-53.939003,-101.939,-96.939,-89.939,-84.939,-85.939,-82.939,-78.939,-72.939,-64.939,-54.939003,-49.939003,-45.939003,-43.939003,-38.939003,-31.939003,-27.939003,-20.939003,-12.939003,-7.939003,-3.939003,0.06099701,5.060997,10.060997,12.060997,14.060997,15.060997,11.060997,9.060997,9.060997,7.060997,3.060997,-1.939003,-7.939003,-13.939003,-17.939003,-22.939003,-28.939003,-32.939003,-38.939003,-44.939003,-51.939003,-59.939003,-62.939003,-64.939,-67.939,-69.939,-71.939,-74.939,-77.939,-80.939,-83.939,-86.939,-90.939,-94.939,-97.939,-98.939,-99.939,-98.939,-4.939003,71.061,77.061,86.061,89.061,30.060997,-21.939003,-69.939,-60.939003,-51.939003,-51.939003,-45.939003,-36.939003,-33.939003,-26.939003,-19.939003,-13.939003,-9.939003,-8.939003,-4.939003,0.06099701,5.060997,7.060997,8.060997,9.060997,9.060997,7.060997,5.060997,2.060997,0.06099701,-2.939003,-3.939003,-10.939003,-17.939003,-24.939003,-27.939003,-31.939003,-38.939003,-45.939003,-52.939003,-56.939003,-59.939003,-63.939003,-66.939,-69.939,-70.939,-72.939,-75.939,-79.939,-82.939,-86.939,-90.939,-93.939,-96.939,-98.939,-98.939,-94.939,-91.939,-95.939,-80.939,-56.939003,-27.939003,-16.939003,-20.939003,5.060997,28.060997,35.060997,27.060997,12.060997,10.060997,7.060997,6.060997,1.060997,-2.939003,-7.939003,-10.939003,-12.939003,-12.939003,-11.939003,-11.939003,-10.939003,-10.939003,-14.939003,-14.939003,-13.939003,-11.939003,-9.939003,-6.939003,-1.939003,2.060997,7.060997,10.060997,13.060997,19.060997,23.060997,26.060997,27.060997,33.060997,50.060997,42.060997,25.060997,35.060997,50.060997,70.061,72.061,70.061,63.060997,4.060997,-72.939,-88.939,-84.939
,-60.939003,-70.939,-80.939,-83.939,-85.939,-87.939,-89.939,-86.939,-79.939,-73.939,-69.939,-67.939,-60.939003,-51.939003,-47.939003,-44.939003,-40.939003,-36.939003,-30.939003,-22.939003,-15.939003,-10.939003,-6.939003,-3.939003,1.060997,5.060997,7.060997,6.060997,9.060997,13.060997,12.060997,10.060997,8.060997,5.060997,2.060997,-3.939003,-7.939003,-10.939003,-13.939003,-15.939003,-18.939003,-52.939003,-81.939,-94.939,-84.939,-67.939,-74.939,-81.939,-88.939,-84.939,-78.939,-69.939,-70.939,-75.939,-80.939,-32.939003,68.061,66.061,58.060997,60.060997,-5.939003,-89.939,-93.939,-90.939,-81.939,-76.939,-71.939,-66.939,-59.939003,-51.939003,-48.939003,-42.939003,-35.939003,-34.939003,-32.939003,-25.939003,-20.939003,-15.939003,-12.939003,-7.939003,-0.939003,-0.939003,-0.939003,2.060997,5.060997,7.060997,6.060997,4.060997,3.060997,-1.939003,-5.939003,-9.939003,-11.939003,-12.939003,-20.939003,-28.939003,-35.939003,-38.939003,-41.939003,-47.939003,-54.939003,-60.939003,-61.939003,-62.939003,-65.939,-69.939,-72.939,-72.939,-20.939003,49.060997,61.060997,64.061,58.060997,-26.939003,-93.939,-96.939,-91.939,-82.939,-80.939,-75.939,-69.939,-66.939,-62.939003,-58.939003,-54.939003,-50.939003,-40.939003,-34.939003,-30.939003,-26.939003,-21.939003,-19.939003,-15.939003,-12.939003,-6.939003,-1.939003,0.06099701,-0.939003,-0.939003,0.06099701,1.060997,3.060997,3.060997,0.06099701,-6.939003,-8.939003,-10.939003,-13.939003,-20.939003,-29.939003,-31.939003,-36.939003,-45.939003,-38.939003,-28.939003,-16.939003,-29.939003,-46.939003,-20.939003,8.060997,40.060997,44.060997,44.060997,45.060997,43.060997,41.060997,43.060997,42.060997,37.060997,32.060997,20.060997,-8.939003,-16.939003,-14.939003,-8.939003,-5.939003,-4.939003,-22.939003,-37.939003,-38.939003,-44.939003,-51.939003,-51.939003,-53.939003,-58.939003,-63.939003,-67.939,-70.939,-77.939,-84.939,-81.939,-79.939,-80.939,-79.939,-80.939,-82.939,-82.939,-80.939,-81.939,-81.939,-84.939,-81.939,-80.939,-82.939,-79.939,-75.939,-79.939,-80
.939,-79.939,-81.939,-84.939,-91.939,-88.939,-82.939,-85.939,-86.939,-87.939,-86.939,-85.939,-85.939,-85.939,-85.939,-84.939,-83.939,-83.939,-81.939,-73.939,-52.939003,-40.939003,-33.939003,-53.939003,-54.939003,-33.939003,-8.939003,0.06099701,-29.939003,-35.939003,-30.939003,-31.939003,-37.939003,-49.939003,-65.939,-77.939,-76.939,-42.939003,-1.939003,-25.939003,-40.939003,-48.939003,-49.939003,-52.939003,-60.939003,-65.939,-69.939,-70.939,-71.939,-74.939,-76.939,-80.939,-82.939,-85.939,-87.939,-80.939,-76.939,-75.939,-68.939,-61.939003,-55.939003,-56.939003,-60.939003,-62.939003,-62.939003,-60.939003,-63.939003,-66.939,-69.939,-71.939,-71.939,-70.939,-71.939,-73.939,-75.939,-75.939,-74.939,-72.939,-70.939,-75.939,-73.939,-64.939,-59.939003,-56.939003,-55.939003,-54.939003,-53.939003,-55.939003,-59.939003,-63.939003,-66.939,-69.939,-73.939,-75.939,-76.939,-77.939,-78.939,-82.939,-85.939,-88.939,-88.939,-89.939,-90.939,-89.939,-88.939,-90.939,-93.939,-96.939,-94.939,-91.939,-88.939,-88.939,-88.939,-87.939,-86.939,-86.939,-87.939,-85.939,-83.939,-85.939,-86.939,-85.939,-86.939,-87.939,-91.939,-90.939,-86.939,-86.939,-86.939,-85.939,-84.939,-85.939,-87.939,-86.939,-85.939,-85.939,-87.939,-90.939,-89.939,-88.939,-90.939,-90.939,-90.939,-91.939,-91.939,-92.939,-89.939,-86.939,-85.939,-83.939,-81.939,-85.939,-87.939,-88.939,-90.939,-91.939,-91.939,-90.939,-89.939,-89.939,-87.939,-85.939,-84.939,-83.939,-85.939,-84.939,-83.939,-83.939,-85.939,-88.939,-85.939,-84.939,-86.939,-87.939,-89.939,-90.939,-91.939,-92.939,-90.939,-89.939,-90.939,-89.939,-87.939,-89.939,-88.939,-84.939,-88.939,-91.939,-90.939,-89.939,-90.939,-93.939,-92.939,-88.939,-91.939,-92.939,-92.939,-90.939,-90.939,-90.939,-89.939,-88.939,-87.939,-88.939,-89.939,-89.939,-89.939,-88.939,-89.939,-92.939,-91.939,-90.939,-90.939,-88.939,-87.939,-88.939,-90.939,-92.939,-90.939,-89.939,-90.939,-89.939,-87.939,-89.939,-89.939,-90.939,-93.939,-94.939,-93.939,-91.939,-89.939,-83.939,-69.939,-48.939003,-51.939003,-53.9
39003,-45.939003,-25.939003,-5.939003,-12.939003,-23.939003,-40.939003,-47.939003,-50.939003,-50.939003,-50.939003,-50.939003,-51.939003,-47.939003,-39.939003,-36.939003,-36.939003,-42.939003,-44.939003,-44.939003,-49.939003,-50.939003,-49.939003,-63.939003,-77.939,-88.939,-84.939,-75.939,-78.939,-78.939,-78.939,-91.939,-98.939,-86.939,-72.939,-59.939003,-73.939,-81.939,-83.939,-82.939,-78.939,-72.939,-75.939,-81.939,-78.939,-78.939,-81.939,-75.939,-68.939,-59.939003,-62.939003,-68.939,-68.939,-69.939,-72.939,-74.939,-75.939,-72.939,-73.939,-75.939,-82.939,-85.939,-83.939,-70.939,-55.939003,-42.939003,-43.939003,-48.939003,-39.939003,-41.939003,-54.939003,-55.939003,-55.939003,-50.939003,-55.939003,-60.939003,-45.939003,-34.939003,-27.939003,-21.939003,-17.939003,-19.939003,-17.939003,-15.939003,-24.939003,-31.939003,-35.939003,-29.939003,-23.939003,-22.939003,-29.939003,-37.939003,-32.939003,-31.939003,-35.939003,-28.939003,-22.939003,22.060997,21.060997,18.060997,3.060997,12.060997,45.060997,49.060997,52.060997,59.060997,62.060997,65.061,67.061,71.061,75.061,76.061,75.061,70.061,81.061,91.061,36.060997,-24.939003,-91.939,-92.939,-88.939,-88.939,-87.939,-86.939,-83.939,-78.939,-71.939,-65.939,-60.939003,-59.939003,-55.939003,-48.939003,-44.939003,-39.939003,-36.939003,-29.939003,-23.939003,-17.939003,-15.939003,-13.939003,-8.939003,-3.939003,-0.939003,-0.939003,-0.939003,-0.939003,3.060997,6.060997,5.060997,6.060997,8.060997,6.060997,5.060997,3.060997,0.06099701,-5.939003,-10.939003,-15.939003,-18.939003,-23.939003,-31.939003,-44.939003,-50.939003,-53.939003,-64.939,-72.939,-75.939,-80.939,-85.939,-94.939,-99.939,-102.939,-103.939,-102.939,-102.939,-100.939,-98.939,-95.939,-93.939,-91.939,-91.939,-50.939003,33.060997,57.060997,66.061,55.060997,49.060997,39.060997,-34.939003,-67.939,-60.939003,-52.939003,-46.939003,-44.939003,-40.939003,-33.939003,-28.939003,-20.939003,-11.939003,-7.939003,-3.939003,-2.939003,1.060997,4.060997,5.060997,7.060997,11.060997,11.060997,1
3.060997,15.060997,14.060997,11.060997,8.060997,3.060997,-2.939003,-9.939003,-15.939003,-18.939003,-23.939003,-29.939003,-37.939003,-45.939003,-52.939003,-58.939003,-65.939,-72.939,-77.939,-82.939,-88.939,-95.939,-102.939,-103.939,-102.939,-101.939,-99.939,-97.939,-95.939,-92.939,-89.939,-86.939,-83.939,-82.939,-80.939,-78.939,-74.939,-69.939,-65.939,-10.939003,34.060997,37.060997,39.060997,37.060997,11.060997,-9.939003,-26.939003,-15.939003,-6.939003,-7.939003,-2.939003,5.060997,3.060997,4.060997,6.060997,8.060997,9.060997,10.060997,7.060997,5.060997,4.060997,2.060997,-1.939003,-6.939003,-12.939003,-15.939003,-21.939003,-27.939003,-33.939003,-37.939003,-39.939003,-47.939003,-56.939003,-66.939,-71.939,-76.939,-82.939,-89.939,-97.939,-100.939,-102.939,-101.939,-98.939,-95.939,-95.939,-93.939,-90.939,-87.939,-84.939,-81.939,-79.939,-76.939,-72.939,-69.939,-67.939,-60.939003,-54.939003,-55.939003,-44.939003,-29.939003,-14.939003,-8.939003,-10.939003,-7.939003,-5.939003,-4.939003,-11.939003,-19.939003,-20.939003,-18.939003,-13.939003,-16.939003,-18.939003,-18.939003,-17.939003,-15.939003,-9.939003,-6.939003,-5.939003,-4.939003,-1.939003,8.060997,14.060997,19.060997,19.060997,22.060997,29.060997,36.060997,42.060997,46.060997,48.060997,51.060997,57.060997,60.060997,62.060997,60.060997,63.060997,80.061,63.060997,36.060997,43.060997,57.060997,79.061,80.061,74.061,54.060997,-15.939003,-97.939,-92.939,-84.939,-73.939,-77.939,-80.939,-79.939,-81.939,-85.939,-78.939,-60.939003,-32.939003,-28.939003,-26.939003,-22.939003,-16.939003,-8.939003,-6.939003,-4.939003,-2.939003,-0.939003,2.060997,6.060997,8.060997,9.060997,12.060997,11.060997,8.060997,8.060997,5.060997,-0.939003,-2.939003,-4.939003,-11.939003,-16.939003,-19.939003,-23.939003,-27.939003,-37.939003,-42.939003,-47.939003,-53.939003,-58.939003,-63.939003,-82.939,-94.939,-90.939,-77.939,-63.939003,-72.939,-80.939,-86.939,-90.939,-95.939,-98.939,-93.939,-83.939,-79.939,-36.939003,45.060997,42.060997,32.060997,29.060997,-9.93
9003,-57.939003,-54.939003,-48.939003,-40.939003,-35.939003,-29.939003,-23.939003,-16.939003,-8.939003,-7.939003,-3.939003,1.060997,0.06099701,0.06099701,3.060997,4.060997,6.060997,7.060997,7.060997,7.060997,2.060997,-2.939003,-3.939003,-7.939003,-11.939003,-17.939003,-21.939003,-25.939003,-24.939003,-28.939003,-40.939003,-45.939003,-46.939003,-54.939003,-65.939,-78.939,-81.939,-84.939,-88.939,-91.939,-94.939,-94.939,-93.939,-91.939,-86.939,-83.939,-82.939,-29.939003,39.060997,45.060997,44.060997,35.060997,-24.939003,-69.939,-66.939,-58.939003,-48.939003,-43.939003,-37.939003,-31.939003,-26.939003,-21.939003,-20.939003,-17.939003,-13.939003,-5.939003,-1.939003,-2.939003,-1.939003,-0.939003,-1.939003,-1.939003,-1.939003,1.060997,0.06099701,-3.939003,-10.939003,-16.939003,-18.939003,-19.939003,-19.939003,-23.939003,-29.939003,-37.939003,-40.939003,-45.939003,-51.939003,-60.939003,-68.939,-71.939,-77.939,-86.939,-63.939003,-40.939003,-24.939003,-40.939003,-64.939,-31.939003,2.060997,37.060997,33.060997,26.060997,24.060997,18.060997,11.060997,12.060997,8.060997,-0.939003,-10.939003,-22.939003,-44.939003,-52.939003,-54.939003,-56.939003,-59.939003,-63.939003,-75.939,-85.939,-85.939,-89.939,-95.939,-94.939,-94.939,-96.939,-96.939,-96.939,-95.939,-91.939,-85.939,-84.939,-84.939,-84.939,-86.939,-88.939,-87.939,-85.939,-83.939,-82.939,-81.939,-82.939,-80.939,-78.939,-78.939,-76.939,-75.939,-79.939,-80.939,-77.939,-79.939,-82.939,-88.939,-85.939,-80.939,-80.939,-81.939,-82.939,-80.939,-80.939,-84.939,-84.939,-84.939,-85.939,-86.939,-86.939,-87.939,-88.939,-88.939,-80.939,-69.939,-67.939,-59.939003,-47.939003,-27.939003,-18.939003,-38.939003,-36.939003,-25.939003,-29.939003,-40.939003,-59.939003,-71.939,-78.939,-72.939,-26.939003,25.060997,-33.939003,-69.939,-83.939,-80.939,-79.939,-83.939,-84.939,-83.939,-81.939,-78.939,-74.939,-71.939,-70.939,-69.939,-69.939,-69.939,-63.939003,-58.939003,-55.939003,-49.939003,-44.939003,-41.939003,-45.939003,-52.939003,-56.939003,-59.939003,
-60.939003,-63.939003,-68.939,-71.939,-73.939,-74.939,-76.939,-78.939,-81.939,-83.939,-85.939,-89.939,-87.939,-85.939,-89.939,-91.939,-88.939,-86.939,-85.939,-84.939,-83.939,-82.939,-84.939,-87.939,-91.939,-92.939,-93.939,-96.939,-97.939,-98.939,-97.939,-96.939,-96.939,-97.939,-98.939,-97.939,-96.939,-95.939,-92.939,-89.939,-88.939,-90.939,-91.939,-90.939,-89.939,-89.939,-92.939,-93.939,-90.939,-88.939,-86.939,-85.939,-83.939,-80.939,-82.939,-82.939,-81.939,-81.939,-81.939,-86.939,-85.939,-81.939,-84.939,-85.939,-85.939,-84.939,-84.939,-85.939,-84.939,-84.939,-84.939,-86.939,-89.939,-88.939,-88.939,-89.939,-90.939,-91.939,-91.939,-91.939,-93.939,-89.939,-86.939,-85.939,-84.939,-82.939,-86.939,-89.939,-90.939,-91.939,-92.939,-90.939,-89.939,-89.939,-87.939,-86.939,-85.939,-84.939,-84.939,-86.939,-85.939,-84.939,-83.939,-83.939,-84.939,-84.939,-84.939,-85.939,-88.939,-91.939,-90.939,-90.939,-90.939,-90.939,-90.939,-90.939,-89.939,-87.939,-90.939,-89.939,-85.939,-89.939,-90.939,-89.939,-89.939,-89.939,-93.939,-93.939,-88.939,-90.939,-91.939,-89.939,-89.939,-90.939,-91.939,-90.939,-89.939,-88.939,-87.939,-88.939,-89.939,-89.939,-89.939,-90.939,-91.939,-89.939,-89.939,-90.939,-89.939,-88.939,-90.939,-92.939,-94.939,-90.939,-87.939,-90.939,-90.939,-88.939,-91.939,-91.939,-90.939,-92.939,-95.939,-97.939,-99.939,-101.939,-100.939,-91.939,-75.939,-78.939,-78.939,-67.939,-30.939003,13.060997,18.060997,-2.939003,-53.939003,-66.939,-71.939,-71.939,-70.939,-69.939,-70.939,-66.939,-56.939003,-53.939003,-52.939003,-54.939003,-56.939003,-58.939003,-59.939003,-59.939003,-56.939003,-69.939,-81.939,-90.939,-86.939,-78.939,-81.939,-81.939,-78.939,-90.939,-98.939,-86.939,-71.939,-57.939003,-77.939,-87.939,-86.939,-80.939,-73.939,-68.939,-72.939,-77.939,-76.939,-78.939,-81.939,-73.939,-67.939,-66.939,-72.939,-77.939,-80.939,-82.939,-85.939,-89.939,-91.939,-84.939,-87.939,-95.939,-99.939,-94.939,-80.939,-72.939,-63.939003,-49.939003,-49.939003,-54.939003,-48.939003,-46.939003,-47.939003,-
43.939003,-43.939003,-53.939003,-56.939003,-55.939003,-40.939003,-33.939003,-35.939003,-33.939003,-31.939003,-31.939003,-31.939003,-30.939003,-34.939003,-36.939003,-36.939003,-34.939003,-31.939003,-23.939003,-25.939003,-32.939003,-31.939003,-33.939003,-36.939003,-27.939003,-19.939003,67.061,63.060997,54.060997,16.060997,25.060997,83.061,86.061,83.061,82.061,83.061,85.061,84.061,86.061,91.061,83.061,74.061,66.061,73.061,80.061,31.060997,-22.939003,-81.939,-71.939,-58.939003,-59.939003,-56.939003,-51.939003,-43.939003,-33.939003,-20.939003,-13.939003,-7.939003,-2.939003,4.060997,11.060997,14.060997,15.060997,15.060997,23.060997,30.060997,33.060997,30.060997,26.060997,24.060997,21.060997,18.060997,9.060997,2.060997,0.06099701,-5.939003,-15.939003,-23.939003,-32.939003,-42.939003,-54.939003,-63.939003,-65.939,-70.939,-77.939,-85.939,-92.939,-95.939,-96.939,-96.939,-98.939,-98.939,-98.939,-99.939,-100.939,-100.939,-101.939,-101.939,-102.939,-103.939,-103.939,-102.939,-101.939,-100.939,-94.939,-87.939,-80.939,-73.939,-67.939,-68.939,-51.939003,-13.939003,5.060997,11.060997,-6.939003,-11.939003,-9.939003,-8.939003,-2.939003,7.060997,10.060997,13.060997,20.060997,26.060997,31.060997,32.060997,32.060997,31.060997,29.060997,27.060997,26.060997,20.060997,13.060997,3.060997,-6.939003,-13.939003,-19.939003,-23.939003,-27.939003,-38.939003,-53.939003,-64.939,-73.939,-81.939,-84.939,-89.939,-93.939,-96.939,-96.939,-97.939,-98.939,-98.939,-99.939,-100.939,-100.939,-101.939,-101.939,-102.939,-102.939,-103.939,-103.939,-101.939,-96.939,-91.939,-84.939,-78.939,-70.939,-62.939003,-51.939003,-44.939003,-40.939003,-35.939003,-28.939003,-16.939003,-4.939003,3.060997,-9.939003,-21.939003,-28.939003,-33.939003,-34.939003,-19.939003,3.060997,33.060997,36.060997,36.060997,33.060997,31.060997,28.060997,12.060997,-0.939003,-11.939003,-18.939003,-25.939003,-26.939003,-35.939003,-45.939003,-56.939003,-65.939,-74.939,-84.939,-93.939,-94.939,-96.939,-96.939,-97.939,-96.939,-94.939,-96.939,-97.939,-
97.939,-98.939,-101.939,-101.939,-102.939,-102.939,-103.939,-102.939,-97.939,-89.939,-79.939,-78.939,-72.939,-63.939003,-54.939003,-46.939003,-38.939003,-30.939003,-21.939003,-13.939003,-6.939003,-2.939003,5.060997,12.060997,19.060997,20.060997,17.060997,5.060997,1.060997,6.060997,-11.939003,-26.939003,-28.939003,-24.939003,-16.939003,-10.939003,-3.939003,3.060997,6.060997,10.060997,20.060997,29.060997,35.060997,13.060997,8.060997,23.060997,21.060997,30.060997,71.061,85.061,86.061,76.061,75.061,82.061,87.061,89.061,88.061,85.061,83.061,79.061,76.061,76.061,72.061,69.061,71.061,51.060997,25.060997,25.060997,25.060997,28.060997,27.060997,25.060997,16.060997,-37.939003,-102.939,-87.939,-83.939,-88.939,-85.939,-82.939,-80.939,-88.939,-96.939,-69.939,-24.939003,40.060997,35.060997,28.060997,32.060997,32.060997,29.060997,25.060997,20.060997,16.060997,10.060997,1.060997,-9.939003,-16.939003,-21.939003,-24.939003,-31.939003,-40.939003,-48.939003,-57.939003,-67.939,-77.939,-85.939,-92.939,-95.939,-95.939,-96.939,-96.939,-97.939,-97.939,-97.939,-98.939,-98.939,-96.939,-98.939,-95.939,-81.939,-65.939,-51.939003,-57.939003,-70.939,-89.939,-93.939,-95.939,-93.939,-72.939,-43.939003,-31.939003,-23.939003,-21.939003,-26.939003,-32.939003,-37.939003,-18.939003,8.060997,20.060997,24.060997,22.060997,23.060997,25.060997,27.060997,29.060997,31.060997,25.060997,19.060997,14.060997,9.060997,3.060997,-9.939003,-18.939003,-24.939003,-28.939003,-37.939003,-47.939003,-56.939003,-65.939,-69.939,-81.939,-92.939,-94.939,-96.939,-96.939,-83.939,-78.939,-93.939,-98.939,-98.939,-96.939,-97.939,-101.939,-101.939,-99.939,-95.939,-86.939,-77.939,-76.939,-73.939,-68.939,-53.939003,-42.939003,-41.939003,-30.939003,-15.939003,-27.939003,-32.939003,-30.939003,-15.939003,-1.939003,4.060997,6.060997,6.060997,11.060997,14.060997,14.060997,20.060997,23.060997,16.060997,13.060997,11.060997,7.060997,-0.939003,-12.939003,-21.939003,-29.939003,-35.939003,-41.939003,-47.939003,-53.939003,-61.939003,-71.939,-80.9
39,-88.939,-89.939,-91.939,-93.939,-95.939,-96.939,-97.939,-97.939,-98.939,-98.939,-99.939,-100.939,-99.939,-100.939,-101.939,-72.939,-46.939003,-36.939003,-46.939003,-61.939003,-42.939003,-29.939003,-21.939003,-40.939003,-58.939003,-66.939,-75.939,-82.939,-77.939,-79.939,-86.939,-91.939,-93.939,-90.939,-91.939,-92.939,-95.939,-94.939,-90.939,-86.939,-82.939,-81.939,-82.939,-84.939,-87.939,-87.939,-84.939,-82.939,-81.939,-80.939,-78.939,-77.939,-79.939,-81.939,-81.939,-79.939,-77.939,-78.939,-80.939,-81.939,-80.939,-80.939,-81.939,-79.939,-78.939,-79.939,-81.939,-82.939,-83.939,-83.939,-82.939,-82.939,-83.939,-84.939,-84.939,-82.939,-81.939,-80.939,-81.939,-80.939,-81.939,-87.939,-87.939,-85.939,-82.939,-83.939,-88.939,-87.939,-84.939,-80.939,-80.939,-82.939,-78.939,-74.939,-72.939,-67.939,-63.939003,-61.939003,-41.939003,-15.939003,-29.939003,-47.939003,-69.939,-80.939,-83.939,-68.939,-13.939003,41.060997,-50.939003,-96.939,-95.939,-85.939,-74.939,-61.939003,-51.939003,-43.939003,-37.939003,-28.939003,-16.939003,-8.939003,-3.939003,-2.939003,-4.939003,-6.939003,-8.939003,-15.939003,-26.939003,-36.939003,-47.939003,-56.939003,-66.939,-75.939,-83.939,-89.939,-94.939,-96.939,-97.939,-95.939,-91.939,-89.939,-89.939,-89.939,-89.939,-88.939,-87.939,-91.939,-88.939,-84.939,-86.939,-88.939,-90.939,-91.939,-92.939,-90.939,-88.939,-86.939,-85.939,-86.939,-91.939,-90.939,-89.939,-90.939,-90.939,-90.939,-89.939,-87.939,-86.939,-87.939,-87.939,-89.939,-87.939,-85.939,-89.939,-89.939,-84.939,-84.939,-85.939,-89.939,-89.939,-86.939,-84.939,-84.939,-85.939,-84.939,-83.939,-83.939,-84.939,-85.939,-87.939,-87.939,-83.939,-80.939,-79.939,-83.939,-82.939,-80.939,-83.939,-85.939,-86.939,-84.939,-83.939,-83.939,-83.939,-83.939,-82.939,-83.939,-85.939,-84.939,-85.939,-86.939,-87.939,-88.939,-87.939,-88.939,-90.939,-88.939,-86.939,-85.939,-85.939,-84.939,-87.939,-89.939,-91.939,-92.939,-92.939,-88.939,-88.939,-89.939,-86.939,-85.939,-85.939,-85.939,-86.939,-88.939,-87.939,-85.939,-84.939,
-83.939,-83.939,-85.939,-86.939,-84.939,-87.939,-90.939,-86.939,-84.939,-82.939,-87.939,-91.939,-89.939,-88.939,-86.939,-89.939,-90.939,-89.939,-90.939,-89.939,-88.939,-88.939,-89.939,-91.939,-90.939,-85.939,-87.939,-88.939,-86.939,-87.939,-89.939,-90.939,-90.939,-90.939,-87.939,-86.939,-87.939,-88.939,-88.939,-89.939,-90.939,-89.939,-88.939,-89.939,-92.939,-92.939,-91.939,-93.939,-94.939,-93.939,-89.939,-87.939,-92.939,-93.939,-91.939,-94.939,-94.939,-89.939,-89.939,-90.939,-92.939,-95.939,-98.939,-98.939,-99.939,-100.939,-99.939,-98.939,-95.939,-54.939003,2.060997,37.060997,12.060997,-75.939,-93.939,-100.939,-100.939,-100.939,-100.939,-100.939,-99.939,-97.939,-98.939,-98.939,-99.939,-99.939,-99.939,-98.939,-98.939,-99.939,-100.939,-98.939,-92.939,-85.939,-80.939,-79.939,-80.939,-81.939,-91.939,-99.939,-101.939,-96.939,-91.939,-97.939,-97.939,-89.939,-80.939,-72.939,-67.939,-69.939,-73.939,-72.939,-73.939,-75.939,-71.939,-71.939,-79.939,-85.939,-88.939,-89.939,-91.939,-94.939,-98.939,-100.939,-93.939,-96.939,-102.939,-103.939,-93.939,-73.939,-67.939,-64.939,-58.939003,-58.939003,-59.939003,-63.939003,-62.939003,-58.939003,-48.939003,-44.939003,-58.939003,-61.939003,-58.939003,-46.939003,-43.939003,-47.939003,-45.939003,-43.939003,-42.939003,-44.939003,-45.939003,-37.939003,-34.939003,-39.939003,-43.939003,-44.939003,-32.939003,-30.939003,-33.939003,-46.939003,-52.939003,-48.939003,-35.939003,-25.939003,72.061,69.061,61.060997,16.060997,15.060997,58.060997,66.061,66.061,56.060997,53.060997,52.060997,47.060997,41.060997,37.060997,29.060997,23.060997,21.060997,19.060997,15.060997,-2.939003,-15.939003,-22.939003,-14.939003,-8.939003,-10.939003,-8.939003,-4.939003,-1.939003,2.060997,5.060997,6.060997,5.060997,1.060997,2.060997,4.060997,1.060997,-4.939003,-13.939003,-15.939003,-16.939003,-20.939003,-26.939003,-33.939003,-37.939003,-42.939003,-48.939003,-53.939003,-50.939003,-27.939003,-23.939003,-29.939003,-57.939003,-74.939,-78.939,-84.939,-89.939,-89.939,-92.939,-95.93
9,-96.939,-95.939,-93.939,-92.939,-91.939,-86.939,-85.939,-84.939,-81.939,-77.939,-72.939,-66.939,-60.939003,-54.939003,-51.939003,-49.939003,-47.939003,-44.939003,-40.939003,-33.939003,-26.939003,-19.939003,-12.939003,-6.939003,-4.939003,-4.939003,-3.939003,-7.939003,-12.939003,-15.939003,-15.939003,-12.939003,-4.939003,1.060997,5.060997,-0.939003,-5.939003,-8.939003,-12.939003,-16.939003,-19.939003,-25.939003,-33.939003,-38.939003,-43.939003,-44.939003,-48.939003,-51.939003,-56.939003,-61.939003,-65.939,-68.939,-70.939,-71.939,-77.939,-84.939,-88.939,-93.939,-97.939,-97.939,-97.939,-98.939,-95.939,-92.939,-89.939,-86.939,-82.939,-78.939,-75.939,-69.939,-63.939003,-58.939003,-54.939003,-49.939003,-43.939003,-42.939003,-40.939003,-32.939003,-24.939003,-15.939003,-10.939003,-6.939003,-1.939003,2.060997,3.060997,1.060997,3.060997,5.060997,7.060997,7.060997,4.060997,1.060997,-1.939003,-4.939003,3.060997,14.060997,10.060997,-3.939003,-29.939003,-33.939003,-36.939003,-40.939003,-43.939003,-44.939003,-52.939003,-58.939003,-64.939,-67.939,-70.939,-71.939,-75.939,-80.939,-85.939,-89.939,-93.939,-97.939,-99.939,-96.939,-94.939,-92.939,-87.939,-83.939,-79.939,-76.939,-72.939,-67.939,-62.939003,-57.939003,-53.939003,-48.939003,-43.939003,-39.939003,-35.939003,-30.939003,-23.939003,-15.939003,-15.939003,-12.939003,-6.939003,-1.939003,0.06099701,-0.939003,0.06099701,2.060997,3.060997,4.060997,2.060997,2.060997,0.06099701,-6.939003,-8.939003,-7.939003,-10.939003,-9.939003,-6.939003,-0.939003,9.060997,28.060997,33.060997,33.060997,41.060997,47.060997,51.060997,55.060997,58.060997,62.060997,71.061,77.061,32.060997,18.060997,34.060997,30.060997,35.060997,71.061,79.061,74.061,59.060997,54.060997,60.060997,56.060997,51.060997,46.060997,42.060997,38.060997,29.060997,23.060997,19.060997,18.060997,17.060997,12.060997,7.060997,3.060997,-1.939003,-7.939003,-14.939003,-6.939003,-5.939003,-26.939003,-58.939003,-89.939,-83.939,-55.939003,-4.939003,-25.939003,-54.939003,-84.939,-92.939,-89.939
,-79.939,-51.939003,-5.939003,-13.939003,-24.939003,-29.939003,-35.939003,-40.939003,-44.939003,-47.939003,-50.939003,-53.939003,-57.939003,-63.939003,-66.939,-68.939,-70.939,-73.939,-78.939,-82.939,-85.939,-90.939,-94.939,-95.939,-92.939,-89.939,-87.939,-80.939,-74.939,-70.939,-66.939,-60.939003,-56.939003,-65.939,-88.939,-94.939,-93.939,-84.939,-30.939003,33.060997,11.060997,-30.939003,-92.939,-92.939,-88.939,-86.939,-60.939003,-25.939003,-3.939003,-1.939003,-20.939003,-20.939003,-18.939003,-14.939003,-13.939003,-13.939003,-12.939003,-15.939003,-20.939003,-25.939003,-29.939003,-31.939003,-36.939003,-42.939003,-45.939003,-48.939003,-51.939003,-53.939003,-56.939003,-62.939003,-67.939,-70.939,-72.939,-75.939,-78.939,-83.939,-87.939,-89.939,-90.939,-89.939,-85.939,-82.939,-79.939,-65.939,-57.939003,-60.939003,-60.939003,-56.939003,-50.939003,-47.939003,-43.939003,-41.939003,-36.939003,-26.939003,-20.939003,-15.939003,-12.939003,-9.939003,-6.939003,-2.939003,-0.939003,-2.939003,-8.939003,-15.939003,-18.939003,-17.939003,-14.939003,-15.939003,-14.939003,-14.939003,-15.939003,-16.939003,-20.939003,-28.939003,-40.939003,-39.939003,-39.939003,-45.939003,-49.939003,-52.939003,-52.939003,-52.939003,-49.939003,-54.939003,-58.939003,-60.939003,-67.939,-77.939,-82.939,-85.939,-85.939,-88.939,-93.939,-97.939,-99.939,-99.939,-97.939,-94.939,-89.939,-90.939,-90.939,-87.939,-85.939,-84.939,-88.939,-91.939,-93.939,-83.939,-74.939,-69.939,-72.939,-76.939,-68.939,-65.939,-65.939,-73.939,-79.939,-80.939,-83.939,-87.939,-85.939,-86.939,-88.939,-86.939,-84.939,-84.939,-85.939,-86.939,-84.939,-84.939,-85.939,-82.939,-80.939,-80.939,-79.939,-79.939,-82.939,-83.939,-81.939,-81.939,-81.939,-79.939,-77.939,-76.939,-79.939,-80.939,-78.939,-79.939,-79.939,-80.939,-81.939,-82.939,-82.939,-81.939,-81.939,-80.939,-79.939,-79.939,-80.939,-82.939,-81.939,-82.939,-84.939,-82.939,-81.939,-79.939,-80.939,-81.939,-80.939,-79.939,-80.939,-80.939,-80.939,-82.939,-84.939,-87.939,-83.939,-82.939,-86.939,-86
.939,-85.939,-80.939,-81.939,-83.939,-78.939,-77.939,-78.939,-78.939,-78.939,-74.939,-63.939003,-50.939003,-58.939003,-63.939003,-65.939,-78.939,-83.939,-64.939,-30.939003,2.060997,-27.939003,-42.939003,-39.939003,-34.939003,-30.939003,-28.939003,-26.939003,-25.939003,-30.939003,-34.939003,-36.939003,-37.939003,-37.939003,-36.939003,-40.939003,-46.939003,-50.939003,-56.939003,-63.939003,-68.939,-73.939,-77.939,-81.939,-85.939,-88.939,-88.939,-89.939,-89.939,-90.939,-88.939,-87.939,-86.939,-87.939,-87.939,-87.939,-85.939,-84.939,-88.939,-87.939,-86.939,-87.939,-88.939,-87.939,-88.939,-88.939,-85.939,-84.939,-85.939,-83.939,-84.939,-89.939,-87.939,-85.939,-85.939,-86.939,-87.939,-87.939,-85.939,-84.939,-85.939,-86.939,-86.939,-86.939,-87.939,-88.939,-88.939,-87.939,-87.939,-86.939,-86.939,-86.939,-85.939,-86.939,-86.939,-85.939,-85.939,-84.939,-83.939,-83.939,-83.939,-86.939,-86.939,-85.939,-82.939,-81.939,-84.939,-83.939,-82.939,-83.939,-84.939,-85.939,-83.939,-82.939,-82.939,-84.939,-84.939,-82.939,-81.939,-82.939,-83.939,-85.939,-87.939,-87.939,-87.939,-87.939,-87.939,-88.939,-88.939,-88.939,-87.939,-86.939,-86.939,-88.939,-88.939,-89.939,-90.939,-90.939,-87.939,-87.939,-87.939,-85.939,-84.939,-85.939,-86.939,-87.939,-89.939,-89.939,-87.939,-85.939,-84.939,-84.939,-84.939,-84.939,-82.939,-84.939,-85.939,-85.939,-86.939,-86.939,-89.939,-90.939,-88.939,-89.939,-89.939,-89.939,-90.939,-91.939,-90.939,-89.939,-89.939,-89.939,-88.939,-90.939,-88.939,-84.939,-87.939,-89.939,-87.939,-88.939,-90.939,-88.939,-88.939,-88.939,-87.939,-87.939,-89.939,-89.939,-88.939,-89.939,-89.939,-88.939,-88.939,-89.939,-93.939,-91.939,-89.939,-90.939,-92.939,-92.939,-90.939,-89.939,-92.939,-92.939,-90.939,-93.939,-94.939,-91.939,-90.939,-91.939,-92.939,-95.939,-99.939,-98.939,-97.939,-94.939,-95.939,-97.939,-99.939,-81.939,-57.939003,-40.939003,-48.939003,-81.939,-86.939,-85.939,-79.939,-76.939,-75.939,-74.939,-72.939,-69.939,-69.939,-69.939,-71.939,-74.939,-77.939,-76.939,-76.939,-79.939,-
86.939,-89.939,-84.939,-82.939,-81.939,-78.939,-82.939,-91.939,-95.939,-98.939,-102.939,-98.939,-93.939,-93.939,-90.939,-83.939,-80.939,-77.939,-73.939,-70.939,-67.939,-70.939,-70.939,-66.939,-66.939,-68.939,-73.939,-76.939,-78.939,-82.939,-85.939,-89.939,-90.939,-90.939,-87.939,-89.939,-91.939,-90.939,-87.939,-82.939,-82.939,-81.939,-76.939,-75.939,-74.939,-75.939,-72.939,-69.939,-58.939003,-53.939003,-63.939003,-66.939,-65.939,-57.939003,-54.939003,-53.939003,-53.939003,-54.939003,-55.939003,-54.939003,-51.939003,-45.939003,-45.939003,-50.939003,-52.939003,-53.939003,-46.939003,-47.939003,-50.939003,-59.939003,-60.939003,-55.939003,-47.939003,-41.939003,57.060997,55.060997,48.060997,10.060997,3.060997,25.060997,32.060997,32.060997,21.060997,16.060997,14.060997,9.060997,1.060997,-8.939003,-13.939003,-14.939003,-12.939003,-18.939003,-27.939003,-21.939003,-9.939003,11.060997,13.060997,13.060997,10.060997,10.060997,10.060997,8.060997,5.060997,0.06099701,-2.939003,-7.939003,-19.939003,-23.939003,-23.939003,-30.939003,-39.939003,-51.939003,-59.939003,-66.939,-73.939,-79.939,-86.939,-90.939,-95.939,-101.939,-102.939,-90.939,-47.939003,-34.939003,-36.939003,-77.939,-96.939,-95.939,-94.939,-93.939,-91.939,-90.939,-88.939,-81.939,-75.939,-70.939,-67.939,-64.939,-55.939003,-52.939003,-50.939003,-45.939003,-39.939003,-33.939003,-23.939003,-14.939003,-7.939003,-2.939003,-0.939003,-0.939003,2.060997,5.060997,10.060997,14.060997,16.060997,20.060997,24.060997,25.060997,17.060997,3.060997,-4.939003,-7.939003,-2.939003,2.060997,6.060997,-8.939003,-17.939003,-20.939003,-30.939003,-40.939003,-49.939003,-57.939003,-66.939,-72.939,-79.939,-89.939,-95.939,-100.939,-100.939,-101.939,-101.939,-100.939,-99.939,-98.939,-96.939,-94.939,-93.939,-92.939,-90.939,-88.939,-87.939,-87.939,-83.939,-78.939,-74.939,-68.939,-62.939003,-57.939003,-51.939003,-45.939003,-40.939003,-36.939003,-27.939003,-19.939003,-11.939003,-6.939003,-0.939003,5.060997,5.060997,6.060997,12.060997,18.060997,23.060997,25.0
60997,25.060997,24.060997,21.060997,16.060997,7.060997,5.060997,2.060997,-1.939003,-9.939003,-19.939003,1.060997,22.060997,36.060997,53.060997,67.061,48.060997,-1.939003,-84.939,-94.939,-97.939,-101.939,-102.939,-101.939,-100.939,-99.939,-99.939,-97.939,-95.939,-93.939,-91.939,-90.939,-89.939,-88.939,-87.939,-83.939,-77.939,-71.939,-67.939,-64.939,-55.939003,-49.939003,-44.939003,-39.939003,-33.939003,-26.939003,-18.939003,-11.939003,-6.939003,-0.939003,4.060997,9.060997,12.060997,14.060997,16.060997,19.060997,17.060997,17.060997,17.060997,16.060997,12.060997,5.060997,-1.939003,-5.939003,-9.939003,-13.939003,-18.939003,-23.939003,-31.939003,-45.939003,-48.939003,-44.939003,-28.939003,-19.939003,-17.939003,8.060997,39.060997,74.061,77.061,70.061,73.061,76.061,77.061,78.061,78.061,77.061,83.061,86.061,36.060997,18.060997,31.060997,25.060997,26.060997,48.060997,50.060997,42.060997,28.060997,23.060997,27.060997,18.060997,9.060997,4.060997,0.06099701,-3.939003,-11.939003,-18.939003,-23.939003,-20.939003,-19.939003,-22.939003,-17.939003,-8.939003,-13.939003,-19.939003,-25.939003,-21.939003,-28.939003,-60.939003,-73.939,-78.939,-78.939,-36.939003,49.060997,19.060997,-27.939003,-81.939,-91.939,-83.939,-87.939,-79.939,-56.939003,-65.939,-76.939,-85.939,-92.939,-97.939,-100.939,-100.939,-99.939,-98.939,-96.939,-93.939,-92.939,-91.939,-89.939,-88.939,-86.939,-84.939,-83.939,-81.939,-79.939,-75.939,-65.939,-58.939003,-53.939003,-45.939003,-37.939003,-32.939003,-23.939003,-15.939003,-18.939003,-41.939003,-83.939,-89.939,-89.939,-85.939,-12.939003,77.061,48.060997,-5.939003,-87.939,-88.939,-81.939,-77.939,-56.939003,-29.939003,-7.939003,4.060997,3.060997,8.060997,13.060997,25.060997,-3.939003,-46.939003,-53.939003,-61.939003,-66.939,-74.939,-81.939,-85.939,-92.939,-98.939,-98.939,-97.939,-94.939,-94.939,-93.939,-92.939,-90.939,-88.939,-86.939,-83.939,-79.939,-80.939,-79.939,-77.939,-69.939,-60.939003,-53.939003,-47.939003,-43.939003,-33.939003,-26.939003,-22.939003,-18.939003,-12
.939003,-8.939003,-4.939003,1.060997,4.060997,8.060997,16.060997,18.060997,16.060997,18.060997,19.060997,20.060997,12.060997,5.060997,0.06099701,-1.939003,-2.939003,-0.939003,2.060997,4.060997,-13.939003,-28.939003,-31.939003,-32.939003,-33.939003,-41.939003,-56.939003,-76.939,-87.939,-93.939,-94.939,-97.939,-99.939,-92.939,-82.939,-67.939,-69.939,-71.939,-70.939,-78.939,-90.939,-95.939,-94.939,-88.939,-86.939,-88.939,-94.939,-97.939,-96.939,-91.939,-85.939,-78.939,-80.939,-80.939,-75.939,-72.939,-71.939,-79.939,-82.939,-83.939,-89.939,-92.939,-90.939,-88.939,-86.939,-86.939,-87.939,-90.939,-90.939,-87.939,-82.939,-81.939,-83.939,-84.939,-84.939,-83.939,-77.939,-73.939,-75.939,-77.939,-79.939,-74.939,-74.939,-78.939,-78.939,-78.939,-77.939,-77.939,-77.939,-79.939,-80.939,-80.939,-80.939,-80.939,-79.939,-77.939,-76.939,-80.939,-80.939,-76.939,-79.939,-81.939,-81.939,-81.939,-82.939,-82.939,-82.939,-81.939,-80.939,-80.939,-79.939,-80.939,-81.939,-80.939,-81.939,-83.939,-82.939,-81.939,-77.939,-78.939,-80.939,-79.939,-79.939,-80.939,-80.939,-80.939,-78.939,-81.939,-87.939,-83.939,-82.939,-85.939,-86.939,-85.939,-82.939,-81.939,-81.939,-78.939,-78.939,-81.939,-84.939,-87.939,-83.939,-81.939,-80.939,-81.939,-76.939,-66.939,-78.939,-82.939,-55.939003,-39.939003,-27.939003,-13.939003,-7.939003,-7.939003,-7.939003,-10.939003,-20.939003,-25.939003,-30.939003,-40.939003,-50.939003,-62.939003,-67.939,-71.939,-70.939,-74.939,-81.939,-84.939,-87.939,-89.939,-89.939,-88.939,-88.939,-87.939,-87.939,-86.939,-83.939,-81.939,-81.939,-81.939,-81.939,-81.939,-82.939,-83.939,-84.939,-84.939,-83.939,-83.939,-86.939,-87.939,-87.939,-88.939,-88.939,-85.939,-85.939,-84.939,-80.939,-82.939,-85.939,-83.939,-83.939,-86.939,-83.939,-82.939,-82.939,-83.939,-84.939,-86.939,-85.939,-84.939,-85.939,-86.939,-83.939,-85.939,-88.939,-87.939,-88.939,-89.939,-90.939,-89.939,-85.939,-84.939,-85.939,-89.939,-89.939,-86.939,-86.939,-86.939,-84.939,-83.939,-82.939,-84.939,-85.939,-86.939,-84.939,-82.939,-84
.939,-84.939,-84.939,-84.939,-85.939,-86.939,-83.939,-81.939,-82.939,-84.939,-85.939,-82.939,-80.939,-79.939,-82.939,-84.939,-86.939,-86.939,-86.939,-87.939,-87.939,-86.939,-88.939,-89.939,-88.939,-87.939,-88.939,-88.939,-88.939,-87.939,-88.939,-88.939,-86.939,-85.939,-84.939,-84.939,-84.939,-85.939,-86.939,-89.939,-91.939,-91.939,-91.939,-87.939,-86.939,-85.939,-83.939,-81.939,-79.939,-80.939,-81.939,-84.939,-87.939,-89.939,-88.939,-88.939,-87.939,-90.939,-92.939,-88.939,-89.939,-92.939,-89.939,-89.939,-90.939,-90.939,-87.939,-89.939,-87.939,-83.939,-87.939,-90.939,-88.939,-88.939,-90.939,-87.939,-85.939,-86.939,-87.939,-88.939,-91.939,-90.939,-88.939,-89.939,-89.939,-88.939,-88.939,-89.939,-93.939,-91.939,-87.939,-88.939,-90.939,-92.939,-90.939,-90.939,-91.939,-91.939,-90.939,-93.939,-94.939,-93.939,-92.939,-92.939,-92.939,-94.939,-97.939,-96.939,-92.939,-86.939,-89.939,-94.939,-97.939,-99.939,-101.939,-101.939,-97.939,-86.939,-81.939,-76.939,-67.939,-63.939003,-61.939003,-60.939003,-57.939003,-51.939003,-48.939003,-46.939003,-48.939003,-50.939003,-53.939003,-52.939003,-53.939003,-56.939003,-69.939,-77.939,-79.939,-80.939,-80.939,-78.939,-83.939,-96.939,-98.939,-98.939,-99.939,-91.939,-80.939,-78.939,-74.939,-69.939,-74.939,-78.939,-76.939,-71.939,-65.939,-68.939,-66.939,-60.939003,-63.939003,-65.939,-65.939,-66.939,-68.939,-74.939,-79.939,-83.939,-80.939,-78.939,-80.939,-80.939,-80.939,-79.939,-83.939,-91.939,-94.939,-94.939,-90.939,-88.939,-86.939,-84.939,-81.939,-79.939,-70.939,-65.939,-70.939,-73.939,-72.939,-67.939,-64.939,-61.939003,-62.939003,-64.939,-67.939,-64.939,-58.939003,-57.939003,-58.939003,-61.939003,-62.939003,-62.939003,-60.939003,-60.939003,-62.939003,-64.939,-61.939003,-55.939003,-57.939003,-59.939003,3.060997,0.06099701,-5.939003,-9.939003,-11.939003,-12.939003,-19.939003,-25.939003,-25.939003,-24.939003,-21.939003,-14.939003,-12.939003,-13.939003,-12.939003,-9.939003,-8.939003,-7.939003,-5.939003,-0.939003,-6.939003,-24.939003,-31.939003,-35.
939003,-37.939003,-43.939003,-51.939003,-58.939003,-64.939,-69.939,-72.939,-76.939,-82.939,-84.939,-85.939,-86.939,-88.939,-91.939,-93.939,-94.939,-96.939,-98.939,-99.939,-98.939,-98.939,-98.939,-101.939,-89.939,-40.939003,-25.939003,-28.939003,-58.939003,-71.939,-67.939,-62.939003,-58.939003,-52.939003,-46.939003,-40.939003,-24.939003,-17.939003,-20.939003,-12.939003,-3.939003,3.060997,8.060997,13.060997,17.060997,18.060997,16.060997,23.060997,26.060997,22.060997,22.060997,20.060997,14.060997,9.060997,6.060997,2.060997,-5.939003,-16.939003,-20.939003,-22.939003,-35.939003,-30.939003,-8.939003,30.060997,55.060997,47.060997,57.060997,65.061,-19.939003,-68.939,-82.939,-86.939,-88.939,-90.939,-92.939,-95.939,-96.939,-98.939,-100.939,-100.939,-99.939,-95.939,-94.939,-92.939,-92.939,-88.939,-81.939,-71.939,-64.939,-60.939003,-55.939003,-49.939003,-41.939003,-36.939003,-35.939003,-25.939003,-13.939003,-2.939003,2.060997,5.060997,11.060997,16.060997,22.060997,20.060997,19.060997,20.060997,25.060997,30.060997,27.060997,20.060997,10.060997,8.060997,4.060997,-5.939003,-13.939003,-20.939003,-26.939003,-31.939003,-36.939003,-47.939003,-57.939003,-70.939,-76.939,-78.939,-80.939,-81.939,-84.939,-26.939003,35.060997,91.061,106.061,101.061,86.061,23.060997,-89.939,-101.939,-102.939,-103.939,-101.939,-96.939,-91.939,-87.939,-84.939,-75.939,-66.939,-60.939003,-53.939003,-46.939003,-40.939003,-36.939003,-34.939003,-22.939003,-12.939003,-3.939003,-0.939003,2.060997,9.060997,12.060997,14.060997,18.060997,21.060997,24.060997,24.060997,24.060997,22.060997,17.060997,9.060997,4.060997,-1.939003,-8.939003,-17.939003,-25.939003,-28.939003,-33.939003,-41.939003,-51.939003,-60.939003,-66.939,-72.939,-79.939,-80.939,-81.939,-83.939,-84.939,-86.939,-90.939,-90.939,-85.939,-43.939003,-19.939003,-13.939003,10.060997,39.060997,74.061,75.061,61.060997,52.060997,45.060997,42.060997,34.060997,29.060997,28.060997,27.060997,23.060997,1.060997,-6.939003,-2.939003,-5.939003,-8.939003,-9.939003,-13.939003,-
15.939003,-15.939003,-14.939003,-11.939003,-16.939003,-19.939003,-16.939003,-17.939003,-18.939003,-16.939003,-15.939003,-14.939003,-8.939003,-1.939003,8.060997,10.060997,8.060997,7.060997,17.060997,35.060997,-0.939003,-38.939003,-70.939,-77.939,-73.939,-73.939,-46.939003,8.060997,7.060997,-11.939003,-61.939003,-79.939,-84.939,-89.939,-91.939,-91.939,-94.939,-96.939,-96.939,-96.939,-96.939,-95.939,-92.939,-84.939,-79.939,-73.939,-61.939003,-54.939003,-49.939003,-44.939003,-36.939003,-28.939003,-22.939003,-16.939003,-10.939003,-5.939003,-1.939003,6.060997,12.060997,17.060997,15.060997,13.060997,11.060997,21.060997,29.060997,-12.939003,-53.939003,-94.939,-88.939,-81.939,-79.939,-43.939003,2.060997,-3.939003,-25.939003,-65.939,-75.939,-76.939,-65.939,-71.939,-81.939,-79.939,-31.939003,61.060997,69.061,72.061,84.061,14.060997,-78.939,-88.939,-93.939,-95.939,-96.939,-98.939,-97.939,-91.939,-83.939,-82.939,-76.939,-66.939,-62.939003,-58.939003,-54.939003,-47.939003,-38.939003,-30.939003,-23.939003,-17.939003,-12.939003,-8.939003,-3.939003,3.060997,11.060997,10.060997,11.060997,14.060997,11.060997,8.060997,5.060997,8.060997,11.060997,3.060997,0.06099701,-0.939003,-3.939003,-9.939003,-18.939003,-24.939003,-28.939003,-38.939003,-45.939003,-48.939003,-61.939003,-73.939,-77.939,-34.939003,22.060997,15.060997,10.060997,7.060997,-8.939003,-21.939003,-22.939003,-21.939003,-18.939003,-16.939003,-24.939003,-42.939003,-75.939,-97.939,-89.939,-86.939,-85.939,-65.939,-47.939003,-32.939003,-37.939003,-43.939003,-47.939003,-57.939003,-68.939,-73.939,-76.939,-76.939,-77.939,-78.939,-82.939,-85.939,-88.939,-83.939,-80.939,-79.939,-79.939,-78.939,-77.939,-77.939,-78.939,-81.939,-81.939,-77.939,-80.939,-83.939,-79.939,-77.939,-78.939,-78.939,-77.939,-75.939,-78.939,-79.939,-74.939,-76.939,-82.939,-79.939,-79.939,-82.939,-78.939,-76.939,-75.939,-79.939,-82.939,-77.939,-75.939,-77.939,-77.939,-76.939,-72.939,-76.939,-82.939,-79.939,-78.939,-78.939,-78.939,-78.939,-77.939,-77.939,-77.939,-81.93
9,-82.939,-77.939,-77.939,-79.939,-80.939,-80.939,-79.939,-81.939,-81.939,-81.939,-79.939,-78.939,-79.939,-78.939,-77.939,-80.939,-82.939,-82.939,-82.939,-82.939,-81.939,-81.939,-82.939,-78.939,-78.939,-82.939,-81.939,-80.939,-78.939,-80.939,-82.939,-81.939,-83.939,-87.939,-84.939,-83.939,-83.939,-80.939,-76.939,-77.939,-80.939,-85.939,-86.939,-85.939,-84.939,-86.939,-87.939,-83.939,-82.939,-83.939,-87.939,-77.939,-35.939003,-21.939003,-22.939003,-32.939003,-41.939003,-49.939003,-55.939003,-61.939003,-71.939,-76.939,-79.939,-78.939,-78.939,-78.939,-79.939,-79.939,-79.939,-79.939,-80.939,-77.939,-75.939,-71.939,-73.939,-73.939,-74.939,-72.939,-71.939,-74.939,-76.939,-76.939,-78.939,-80.939,-80.939,-80.939,-79.939,-78.939,-80.939,-83.939,-85.939,-85.939,-85.939,-86.939,-88.939,-89.939,-88.939,-85.939,-84.939,-83.939,-81.939,-84.939,-89.939,-86.939,-85.939,-84.939,-82.939,-80.939,-80.939,-81.939,-83.939,-85.939,-86.939,-85.939,-86.939,-85.939,-81.939,-82.939,-84.939,-86.939,-88.939,-91.939,-92.939,-92.939,-90.939,-89.939,-88.939,-89.939,-89.939,-87.939,-89.939,-90.939,-89.939,-87.939,-84.939,-84.939,-84.939,-86.939,-85.939,-84.939,-83.939,-86.939,-88.939,-86.939,-86.939,-87.939,-84.939,-82.939,-83.939,-85.939,-86.939,-81.939,-79.939,-78.939,-80.939,-83.939,-84.939,-85.939,-86.939,-88.939,-88.939,-85.939,-87.939,-88.939,-87.939,-86.939,-87.939,-88.939,-86.939,-84.939,-84.939,-84.939,-84.939,-83.939,-82.939,-83.939,-84.939,-85.939,-88.939,-91.939,-94.939,-96.939,-96.939,-92.939,-89.939,-85.939,-82.939,-79.939,-76.939,-75.939,-76.939,-79.939,-82.939,-84.939,-83.939,-83.939,-86.939,-89.939,-91.939,-87.939,-86.939,-90.939,-88.939,-87.939,-91.939,-88.939,-85.939,-88.939,-88.939,-84.939,-88.939,-90.939,-88.939,-87.939,-88.939,-85.939,-84.939,-85.939,-85.939,-87.939,-91.939,-90.939,-87.939,-88.939,-89.939,-90.939,-90.939,-91.939,-93.939,-90.939,-86.939,-88.939,-89.939,-90.939,-91.939,-90.939,-87.939,-88.939,-89.939,-92.939,-94.939,-95.939,-96.939,-95.939,-94.939,-91.939,-89.93
9,-87.939,-85.939,-83.939,-86.939,-90.939,-92.939,-94.939,-96.939,-96.939,-95.939,-91.939,-91.939,-91.939,-92.939,-91.939,-90.939,-89.939,-82.939,-71.939,-67.939,-61.939003,-50.939003,-44.939003,-40.939003,-42.939003,-42.939003,-41.939003,-54.939003,-69.939,-84.939,-83.939,-77.939,-77.939,-81.939,-88.939,-94.939,-96.939,-90.939,-68.939,-42.939003,-39.939003,-39.939003,-41.939003,-56.939003,-69.939,-75.939,-75.939,-74.939,-67.939,-63.939003,-62.939003,-65.939,-66.939,-59.939003,-59.939003,-62.939003,-71.939,-76.939,-78.939,-71.939,-67.939,-71.939,-75.939,-77.939,-78.939,-85.939,-98.939,-98.939,-96.939,-92.939,-90.939,-88.939,-90.939,-90.939,-88.939,-84.939,-81.939,-82.939,-81.939,-80.939,-75.939,-73.939,-73.939,-73.939,-74.939,-76.939,-73.939,-69.939,-74.939,-74.939,-71.939,-70.939,-70.939,-69.939,-64.939,-57.939003,-52.939003,-47.939003,-44.939003,-64.939,-80.939,-10.939003,-13.939003,-16.939003,-11.939003,-8.939003,-6.939003,-8.939003,-8.939003,-3.939003,0.06099701,4.060997,12.060997,18.060997,22.060997,27.060997,32.060997,32.060997,36.060997,41.060997,42.060997,13.060997,-45.939003,-63.939003,-73.939,-74.939,-78.939,-84.939,-89.939,-92.939,-93.939,-94.939,-94.939,-95.939,-95.939,-93.939,-90.939,-86.939,-82.939,-82.939,-81.939,-79.939,-75.939,-71.939,-66.939,-62.939003,-56.939003,-56.939003,-49.939003,-22.939003,-14.939003,-14.939003,-23.939003,-23.939003,-16.939003,-13.939003,-10.939003,-4.939003,-2.939003,-1.939003,6.060997,8.060997,4.060997,7.060997,9.060997,7.060997,7.060997,7.060997,3.060997,-1.939003,-7.939003,-7.939003,-11.939003,-22.939003,-26.939003,-28.939003,-34.939003,-38.939003,-41.939003,-45.939003,-51.939003,-60.939003,-63.939003,-65.939,-75.939,-62.939003,-29.939003,43.060997,88.061,65.061,73.061,83.061,-11.939003,-70.939,-95.939,-92.939,-87.939,-85.939,-82.939,-78.939,-75.939,-69.939,-62.939003,-57.939003,-53.939003,-51.939003,-46.939003,-40.939003,-38.939003,-34.939003,-28.939003,-18.939003,-11.939003,-8.939003,-4.939003,-0.939003,1.060997,2.06099
7,1.060997,4.060997,7.060997,8.060997,7.060997,7.060997,6.060997,4.060997,1.060997,-3.939003,-8.939003,-14.939003,-18.939003,-20.939003,-24.939003,-29.939003,-38.939003,-39.939003,-43.939003,-51.939003,-57.939003,-64.939,-68.939,-72.939,-76.939,-83.939,-89.939,-97.939,-99.939,-99.939,-93.939,-92.939,-96.939,-43.939003,14.060997,71.061,80.061,66.061,56.060997,15.060997,-57.939003,-61.939003,-59.939003,-55.939003,-51.939003,-45.939003,-38.939003,-33.939003,-28.939003,-21.939003,-15.939003,-11.939003,-7.939003,-5.939003,0.06099701,1.060997,-0.939003,2.060997,4.060997,3.060997,2.060997,0.06099701,0.06099701,-1.939003,-6.939003,-9.939003,-13.939003,-18.939003,-22.939003,-25.939003,-27.939003,-32.939003,-39.939003,-43.939003,-48.939003,-53.939003,-60.939003,-67.939,-69.939,-72.939,-78.939,-85.939,-91.939,-94.939,-96.939,-98.939,-96.939,-92.939,-86.939,-85.939,-83.939,-79.939,-75.939,-67.939,-34.939003,-14.939003,-8.939003,2.060997,14.060997,27.060997,24.060997,15.060997,10.060997,6.060997,2.060997,-3.939003,-7.939003,-8.939003,-9.939003,-9.939003,-10.939003,-9.939003,-8.939003,-7.939003,-8.939003,-10.939003,-9.939003,-7.939003,-6.939003,-3.939003,1.060997,4.060997,8.060997,14.060997,14.060997,13.060997,19.060997,24.060997,27.060997,35.060997,43.060997,53.060997,48.060997,36.060997,30.060997,46.060997,82.061,14.060997,-50.939003,-82.939,-83.939,-70.939,-44.939003,-20.939003,3.060997,27.060997,33.060997,-4.939003,-49.939003,-92.939,-92.939,-91.939,-92.939,-71.939,-52.939003,-48.939003,-45.939003,-42.939003,-38.939003,-33.939003,-25.939003,-25.939003,-22.939003,-13.939003,-9.939003,-6.939003,-4.939003,-0.939003,2.060997,3.060997,3.060997,5.060997,5.060997,3.060997,2.060997,-0.939003,-3.939003,-9.939003,-15.939003,-24.939003,-21.939003,-20.939003,-49.939003,-73.939,-92.939,-84.939,-67.939,-35.939003,-15.939003,-0.939003,28.060997,21.060997,-22.939003,-62.939003,-85.939,-72.939,-76.939,-87.939,-93.939,-46.939003,50.060997,55.060997,51.060997,52.060997,3.060997,-59.939003,-59.9
39003,-58.939003,-56.939003,-53.939003,-50.939003,-48.939003,-40.939003,-30.939003,-27.939003,-23.939003,-17.939003,-14.939003,-10.939003,-7.939003,-5.939003,-4.939003,-1.939003,1.060997,5.060997,3.060997,1.060997,0.06099701,1.060997,1.060997,-8.939003,-14.939003,-18.939003,-13.939003,-13.939003,-25.939003,-30.939003,-34.939003,-37.939003,-40.939003,-45.939003,-48.939003,-53.939003,-61.939003,-65.939,-67.939,-67.939,-69.939,-71.939,-80.939,-88.939,-90.939,-49.939003,4.060997,-5.939003,-13.939003,-18.939003,-29.939003,-37.939003,-36.939003,-36.939003,-37.939003,-37.939003,-42.939003,-53.939003,-73.939,-86.939,-82.939,-81.939,-80.939,-70.939,-61.939003,-55.939003,-58.939003,-61.939003,-61.939003,-67.939,-73.939,-75.939,-77.939,-78.939,-78.939,-78.939,-79.939,-80.939,-81.939,-77.939,-76.939,-79.939,-79.939,-78.939,-76.939,-77.939,-79.939,-80.939,-80.939,-78.939,-79.939,-80.939,-76.939,-74.939,-75.939,-74.939,-74.939,-73.939,-75.939,-75.939,-73.939,-76.939,-81.939,-78.939,-79.939,-83.939,-80.939,-78.939,-77.939,-78.939,-79.939,-77.939,-75.939,-74.939,-73.939,-72.939,-74.939,-78.939,-81.939,-78.939,-77.939,-78.939,-78.939,-77.939,-76.939,-77.939,-78.939,-82.939,-81.939,-75.939,-76.939,-78.939,-79.939,-79.939,-78.939,-80.939,-81.939,-81.939,-79.939,-78.939,-79.939,-79.939,-77.939,-79.939,-80.939,-81.939,-82.939,-83.939,-82.939,-82.939,-82.939,-78.939,-78.939,-81.939,-79.939,-78.939,-77.939,-78.939,-80.939,-80.939,-82.939,-87.939,-84.939,-83.939,-84.939,-81.939,-77.939,-79.939,-82.939,-86.939,-86.939,-85.939,-84.939,-84.939,-84.939,-84.939,-85.939,-88.939,-90.939,-83.939,-54.939003,-47.939003,-50.939003,-61.939003,-68.939,-74.939,-75.939,-76.939,-84.939,-86.939,-85.939,-81.939,-80.939,-82.939,-81.939,-80.939,-78.939,-78.939,-77.939,-77.939,-75.939,-71.939,-71.939,-71.939,-71.939,-69.939,-69.939,-70.939,-71.939,-72.939,-74.939,-76.939,-77.939,-78.939,-79.939,-79.939,-79.939,-81.939,-81.939,-82.939,-84.939,-86.939,-87.939,-86.939,-86.939,-86.939,-84.939,-83.939,-83.939,-85.9
39,-87.939,-84.939,-82.939,-80.939,-80.939,-80.939,-80.939,-80.939,-81.939,-82.939,-82.939,-82.939,-84.939,-84.939,-81.939,-81.939,-83.939,-85.939,-86.939,-87.939,-88.939,-90.939,-91.939,-91.939,-89.939,-89.939,-89.939,-90.939,-89.939,-88.939,-89.939,-88.939,-86.939,-84.939,-84.939,-85.939,-83.939,-82.939,-81.939,-84.939,-87.939,-87.939,-88.939,-89.939,-86.939,-83.939,-83.939,-83.939,-84.939,-81.939,-79.939,-78.939,-80.939,-81.939,-82.939,-84.939,-86.939,-87.939,-86.939,-85.939,-86.939,-87.939,-87.939,-86.939,-86.939,-87.939,-86.939,-83.939,-83.939,-82.939,-83.939,-81.939,-79.939,-82.939,-84.939,-84.939,-87.939,-92.939,-94.939,-95.939,-94.939,-93.939,-89.939,-84.939,-83.939,-81.939,-77.939,-77.939,-78.939,-79.939,-80.939,-81.939,-81.939,-82.939,-84.939,-86.939,-88.939,-86.939,-86.939,-89.939,-87.939,-87.939,-90.939,-88.939,-86.939,-88.939,-88.939,-84.939,-87.939,-88.939,-88.939,-87.939,-87.939,-85.939,-83.939,-84.939,-85.939,-87.939,-89.939,-88.939,-86.939,-87.939,-89.939,-90.939,-90.939,-91.939,-92.939,-90.939,-87.939,-89.939,-89.939,-90.939,-90.939,-89.939,-85.939,-87.939,-90.939,-92.939,-93.939,-95.939,-94.939,-93.939,-90.939,-88.939,-88.939,-85.939,-84.939,-85.939,-85.939,-88.939,-91.939,-93.939,-94.939,-93.939,-93.939,-93.939,-94.939,-95.939,-96.939,-96.939,-96.939,-98.939,-95.939,-88.939,-86.939,-82.939,-74.939,-69.939,-65.939,-67.939,-67.939,-65.939,-73.939,-79.939,-79.939,-77.939,-76.939,-77.939,-80.939,-85.939,-89.939,-93.939,-93.939,-79.939,-62.939003,-59.939003,-57.939003,-56.939003,-69.939,-78.939,-74.939,-72.939,-70.939,-66.939,-64.939,-63.939003,-65.939,-65.939,-59.939003,-60.939003,-63.939003,-70.939,-75.939,-76.939,-71.939,-69.939,-75.939,-75.939,-73.939,-79.939,-86.939,-96.939,-92.939,-89.939,-89.939,-84.939,-78.939,-82.939,-85.939,-85.939,-84.939,-83.939,-85.939,-84.939,-80.939,-70.939,-66.939,-68.939,-69.939,-72.939,-73.939,-71.939,-69.939,-74.939,-75.939,-72.939,-73.939,-73.939,-73.939,-68.939,-61.939003,-51.939003,-46.939003,-48.939003,-69.939,-
84.939,-3.939003,-4.939003,-4.939003,-3.939003,2.060997,15.060997,29.060997,42.060997,46.060997,51.060997,54.060997,57.060997,64.061,72.061,79.061,84.061,83.061,87.061,92.061,87.061,38.060997,-55.939003,-85.939,-101.939,-100.939,-99.939,-99.939,-99.939,-96.939,-93.939,-90.939,-86.939,-81.939,-78.939,-75.939,-68.939,-59.939003,-50.939003,-48.939003,-47.939003,-43.939003,-35.939003,-25.939003,-19.939003,-12.939003,-3.939003,0.06099701,0.06099701,-1.939003,-2.939003,-0.939003,12.060997,24.060997,33.060997,31.060997,30.060997,35.060997,29.060997,21.060997,16.060997,11.060997,9.060997,3.060997,-4.939003,-16.939003,-23.939003,-29.939003,-43.939003,-53.939003,-60.939003,-68.939,-79.939,-92.939,-97.939,-100.939,-101.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-100.939,-99.939,-82.939,-49.939003,42.060997,100.061,63.060997,66.061,78.061,3.060997,-49.939003,-80.939,-70.939,-60.939003,-55.939003,-48.939003,-40.939003,-33.939003,-22.939003,-7.939003,0.06099701,4.060997,5.060997,12.060997,21.060997,24.060997,26.060997,28.060997,34.060997,38.060997,38.060997,38.060997,38.060997,31.060997,25.060997,20.060997,11.060997,2.060997,-11.939003,-18.939003,-22.939003,-31.939003,-42.939003,-54.939003,-61.939003,-69.939,-79.939,-92.939,-102.939,-102.939,-102.939,-102.939,-102.939,-101.939,-101.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-95.939,-92.939,-88.939,-75.939,-72.939,-80.939,-50.939003,-16.939003,21.060997,21.060997,7.060997,0.06099701,-4.939003,-11.939003,-6.939003,-0.939003,7.060997,12.060997,18.060997,22.060997,27.060997,33.060997,33.060997,33.060997,32.060997,28.060997,24.060997,26.060997,23.060997,14.060997,5.060997,-5.939003,-18.939003,-25.939003,-32.939003,-38.939003,-47.939003,-59.939003,-69.939,-80.939,-93.939,-100.939,-103.939,-103.939,-102.939,-101.939,-101.939,-101.939,-101.939,-100.939,-99.939,-99.939,-99.939,-98.939,-97.939,-97.939,-96.939,-92.939,-87.939,-82.939,-72.939,-58.939003,-55.939003,-51.939003,-42.939003,-34.939003,-24.939003,-
14.939003,-6.939003,-1.939003,-7.939003,-16.939003,-30.939003,-36.939003,-38.939003,-31.939003,-27.939003,-30.939003,-30.939003,-31.939003,-31.939003,-28.939003,-21.939003,-9.939003,-2.939003,-0.939003,2.060997,6.060997,16.060997,24.060997,31.060997,25.060997,28.060997,36.060997,47.060997,56.060997,63.060997,63.060997,61.060997,67.061,72.061,76.061,85.061,91.061,94.061,81.061,62.060997,46.060997,63.060997,111.061,21.060997,-62.939003,-93.939,-90.939,-69.939,-12.939003,13.060997,9.060997,52.060997,78.061,54.060997,-19.939003,-102.939,-95.939,-88.939,-81.939,-29.939003,11.060997,16.060997,20.060997,24.060997,29.060997,34.060997,38.060997,32.060997,28.060997,27.060997,27.060997,25.060997,21.060997,17.060997,12.060997,5.060997,-2.939003,-8.939003,-15.939003,-22.939003,-35.939003,-48.939003,-60.939003,-66.939,-74.939,-87.939,-94.939,-99.939,-98.939,-94.939,-86.939,-80.939,-55.939003,15.060997,28.060997,22.060997,86.061,85.061,22.060997,-51.939003,-99.939,-84.939,-79.939,-77.939,-79.939,-48.939003,14.060997,10.060997,-0.939003,-10.939003,-16.939003,-19.939003,-8.939003,0.06099701,4.060997,11.060997,17.060997,17.060997,22.060997,29.060997,31.060997,30.060997,27.060997,28.060997,29.060997,31.060997,24.060997,12.060997,6.060997,2.060997,2.060997,-8.939003,-19.939003,-27.939003,-34.939003,-42.939003,-59.939003,-73.939,-84.939,-65.939,-56.939003,-75.939,-89.939,-100.939,-92.939,-91.939,-97.939,-97.939,-96.939,-94.939,-93.939,-91.939,-77.939,-71.939,-74.939,-76.939,-78.939,-76.939,-57.939003,-32.939003,-42.939003,-50.939003,-55.939003,-61.939003,-64.939,-61.939003,-63.939003,-69.939,-73.939,-78.939,-80.939,-75.939,-72.939,-76.939,-78.939,-79.939,-87.939,-92.939,-97.939,-95.939,-93.939,-88.939,-88.939,-87.939,-85.939,-85.939,-84.939,-84.939,-83.939,-80.939,-78.939,-77.939,-74.939,-74.939,-79.939,-80.939,-79.939,-74.939,-75.939,-78.939,-78.939,-80.939,-83.939,-81.939,-79.939,-77.939,-75.939,-73.939,-72.939,-73.939,-76.939,-74.939,-73.939,-75.939,-78.939,-81.939,-79.939,-80.939,-8
3.939,-82.939,-80.939,-79.939,-77.939,-75.939,-75.939,-74.939,-72.939,-68.939,-68.939,-79.939,-81.939,-79.939,-77.939,-77.939,-78.939,-79.939,-78.939,-74.939,-76.939,-80.939,-81.939,-78.939,-73.939,-76.939,-79.939,-80.939,-80.939,-78.939,-81.939,-81.939,-80.939,-78.939,-78.939,-80.939,-80.939,-78.939,-77.939,-78.939,-81.939,-83.939,-84.939,-82.939,-82.939,-82.939,-79.939,-78.939,-79.939,-77.939,-75.939,-76.939,-77.939,-78.939,-78.939,-80.939,-84.939,-84.939,-85.939,-85.939,-83.939,-81.939,-82.939,-83.939,-85.939,-85.939,-84.939,-83.939,-80.939,-78.939,-83.939,-86.939,-88.939,-90.939,-90.939,-86.939,-87.939,-88.939,-91.939,-91.939,-90.939,-82.939,-77.939,-80.939,-78.939,-74.939,-70.939,-73.939,-81.939,-80.939,-78.939,-76.939,-75.939,-75.939,-79.939,-80.939,-78.939,-76.939,-73.939,-72.939,-72.939,-72.939,-70.939,-69.939,-68.939,-70.939,-72.939,-74.939,-77.939,-80.939,-81.939,-80.939,-79.939,-76.939,-76.939,-82.939,-85.939,-86.939,-83.939,-84.939,-88.939,-85.939,-83.939,-86.939,-85.939,-83.939,-80.939,-78.939,-77.939,-79.939,-81.939,-79.939,-79.939,-79.939,-78.939,-77.939,-78.939,-81.939,-83.939,-82.939,-81.939,-82.939,-84.939,-84.939,-83.939,-84.939,-86.939,-90.939,-91.939,-90.939,-89.939,-90.939,-93.939,-87.939,-84.939,-86.939,-87.939,-87.939,-84.939,-83.939,-84.939,-81.939,-80.939,-80.939,-81.939,-84.939,-87.939,-89.939,-91.939,-88.939,-84.939,-82.939,-81.939,-82.939,-80.939,-79.939,-78.939,-79.939,-79.939,-81.939,-83.939,-86.939,-85.939,-84.939,-85.939,-85.939,-86.939,-88.939,-87.939,-85.939,-86.939,-86.939,-83.939,-82.939,-81.939,-82.939,-80.939,-77.939,-81.939,-83.939,-82.939,-86.939,-91.939,-93.939,-92.939,-90.939,-91.939,-88.939,-83.939,-84.939,-85.939,-80.939,-80.939,-82.939,-80.939,-79.939,-78.939,-80.939,-82.939,-82.939,-83.939,-84.939,-87.939,-88.939,-88.939,-87.939,-86.939,-88.939,-88.939,-88.939,-89.939,-88.939,-85.939,-85.939,-85.939,-87.939,-86.939,-85.939,-84.939,-84.939,-85.939,-86.939,-86.939,-86.939,-85.939,-84.939,-86.939,-88.939,-90.939,-90.939,-9
0.939,-91.939,-90.939,-88.939,-90.939,-90.939,-89.939,-88.939,-87.939,-83.939,-86.939,-92.939,-91.939,-92.939,-93.939,-91.939,-89.939,-85.939,-86.939,-89.939,-86.939,-86.939,-88.939,-86.939,-86.939,-91.939,-93.939,-94.939,-91.939,-90.939,-93.939,-93.939,-93.939,-93.939,-93.939,-94.939,-98.939,-101.939,-101.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-102.939,-96.939,-73.939,-70.939,-75.939,-75.939,-78.939,-83.939,-85.939,-89.939,-100.939,-102.939,-101.939,-96.939,-92.939,-88.939,-93.939,-92.939,-74.939,-67.939,-63.939003,-66.939,-67.939,-65.939,-66.939,-66.939,-64.939,-65.939,-68.939,-72.939,-75.939,-75.939,-74.939,-75.939,-80.939,-75.939,-68.939,-81.939,-88.939,-89.939,-84.939,-80.939,-85.939,-75.939,-63.939003,-70.939,-75.939,-78.939,-79.939,-80.939,-84.939,-82.939,-77.939,-61.939003,-54.939003,-56.939003,-60.939003,-64.939,-66.939,-65.939,-63.939003,-68.939,-70.939,-69.939,-71.939,-73.939,-74.939,-72.939,-68.939,-55.939003,-52.939003,-60.939003,-72.939,-81.939,44.060997,48.060997,50.060997,23.060997,14.060997,26.060997,62.060997,87.061,77.061,77.061,81.061,81.061,82.061,83.061,82.061,81.061,80.061,78.061,74.061,69.061,33.060997,-31.939003,-64.939,-82.939,-71.939,-62.939003,-55.939003,-54.939003,-49.939003,-42.939003,-36.939003,-29.939003,-22.939003,-17.939003,-15.939003,-9.939003,-3.939003,2.060997,6.060997,8.060997,9.060997,12.060997,15.060997,12.060997,11.060997,11.060997,7.060997,2.060997,0.06099701,-3.939003,-7.939003,-14.939003,-18.939003,-21.939003,-29.939003,-35.939003,-38.939003,-43.939003,-47.939003,-58.939003,-64.939,-65.939,-67.939,-70.939,-73.939,-76.939,-78.939,-82.939,-86.939,-89.939,-91.939,-95.939,-98.939,-100.939,-99.939,-96.939,-93.939,-90.939,-86.939,-82.939,-74.939,-71.939,-70.939,-64.939,-54.939003,-41.939003,-0.939003,26.060997,9.060997,4.060997,4.060997,-7.939003,-10.939003,-6.939003,0.06099701,4.060997,3.060997,8.060997,14.060997,14.060997,16.060997,19.060997,19.060997,17.060997,12.060997,9.060997,6.060997,-
2.939003,-9.939003,-13.939003,-19.939003,-25.939003,-31.939003,-36.939003,-41.939003,-53.939003,-59.939003,-62.939003,-64.939,-66.939,-72.939,-75.939,-76.939,-76.939,-80.939,-87.939,-89.939,-91.939,-94.939,-97.939,-99.939,-95.939,-93.939,-91.939,-87.939,-82.939,-77.939,-69.939,-62.939003,-58.939003,-53.939003,-47.939003,-41.939003,-36.939003,-34.939003,-30.939003,-26.939003,-13.939003,-4.939003,-0.939003,0.06099701,-3.939003,-13.939003,-15.939003,-13.939003,-15.939003,-8.939003,5.060997,11.060997,14.060997,15.060997,11.060997,5.060997,-2.939003,-9.939003,-16.939003,-23.939003,-29.939003,-31.939003,-39.939003,-48.939003,-54.939003,-59.939003,-64.939,-67.939,-70.939,-75.939,-77.939,-79.939,-81.939,-84.939,-88.939,-91.939,-95.939,-98.939,-100.939,-100.939,-100.939,-95.939,-86.939,-86.939,-84.939,-79.939,-71.939,-63.939003,-60.939003,-54.939003,-45.939003,-39.939003,-33.939003,-27.939003,-22.939003,-18.939003,-11.939003,-4.939003,0.06099701,5.060997,10.060997,16.060997,17.060997,17.060997,6.060997,1.060997,3.060997,-0.939003,-5.939003,-9.939003,-8.939003,-5.939003,3.060997,9.060997,11.060997,14.060997,19.060997,27.060997,36.060997,43.060997,19.060997,12.060997,21.060997,19.060997,27.060997,62.060997,78.061,84.061,68.061,63.060997,69.061,72.061,73.061,70.061,68.061,64.061,56.060997,51.060997,47.060997,52.060997,53.060997,46.060997,35.060997,21.060997,7.060997,12.060997,33.060997,-21.939003,-73.939,-96.939,-94.939,-80.939,-40.939003,-23.939003,-27.939003,2.060997,18.060997,-8.939003,-55.939003,-103.939,-100.939,-96.939,-90.939,-34.939003,10.060997,15.060997,9.060997,0.06099701,-6.939003,-12.939003,-16.939003,-26.939003,-34.939003,-42.939003,-46.939003,-51.939003,-58.939003,-63.939003,-64.939,-67.939,-69.939,-71.939,-74.939,-76.939,-80.939,-85.939,-89.939,-91.939,-93.939,-96.939,-95.939,-94.939,-93.939,-92.939,-92.939,-89.939,-70.939,-18.939003,-6.939003,-7.939003,43.060997,45.060997,-1.939003,-60.939003,-100.939,-88.939,-89.939,-90.939,-43.939003,-13.939003,-0.939003,-8.9
39003,-16.939003,-15.939003,-5.939003,7.060997,12.060997,15.060997,11.060997,9.060997,4.060997,-7.939003,-12.939003,-13.939003,-17.939003,-25.939003,-37.939003,-41.939003,-45.939003,-51.939003,-57.939003,-64.939,-66.939,-68.939,-67.939,-71.939,-75.939,-78.939,-80.939,-83.939,-88.939,-92.939,-97.939,-78.939,-65.939,-66.939,-71.939,-76.939,-57.939003,-47.939003,-44.939003,-37.939003,-26.939003,-7.939003,-5.939003,-12.939003,-18.939003,-31.939003,-51.939003,-65.939,-75.939,-78.939,-77.939,-73.939,-77.939,-81.939,-85.939,-88.939,-89.939,-87.939,-86.939,-88.939,-86.939,-85.939,-83.939,-81.939,-79.939,-79.939,-80.939,-82.939,-84.939,-85.939,-85.939,-84.939,-84.939,-83.939,-83.939,-82.939,-80.939,-80.939,-83.939,-83.939,-80.939,-77.939,-79.939,-83.939,-80.939,-77.939,-76.939,-77.939,-77.939,-74.939,-74.939,-76.939,-77.939,-79.939,-80.939,-78.939,-77.939,-77.939,-76.939,-74.939,-75.939,-76.939,-79.939,-75.939,-74.939,-78.939,-77.939,-74.939,-77.939,-79.939,-78.939,-79.939,-81.939,-82.939,-81.939,-80.939,-80.939,-79.939,-77.939,-75.939,-75.939,-77.939,-80.939,-81.939,-79.939,-79.939,-81.939,-80.939,-79.939,-79.939,-77.939,-75.939,-79.939,-81.939,-80.939,-83.939,-84.939,-82.939,-79.939,-75.939,-76.939,-77.939,-77.939,-78.939,-79.939,-79.939,-79.939,-79.939,-77.939,-78.939,-82.939,-81.939,-79.939,-77.939,-78.939,-80.939,-79.939,-79.939,-78.939,-77.939,-77.939,-79.939,-80.939,-79.939,-79.939,-77.939,-74.939,-78.939,-82.939,-83.939,-83.939,-83.939,-85.939,-86.939,-83.939,-84.939,-85.939,-82.939,-77.939,-75.939,-81.939,-84.939,-83.939,-85.939,-85.939,-83.939,-83.939,-83.939,-86.939,-88.939,-87.939,-81.939,-78.939,-77.939,-78.939,-79.939,-79.939,-80.939,-78.939,-79.939,-81.939,-83.939,-84.939,-82.939,-82.939,-79.939,-75.939,-77.939,-77.939,-77.939,-77.939,-77.939,-76.939,-74.939,-71.939,-72.939,-73.939,-75.939,-76.939,-77.939,-76.939,-76.939,-76.939,-77.939,-78.939,-79.939,-82.939,-85.939,-84.939,-86.939,-90.939,-85.939,-82.939,-85.939,-84.939,-82.939,-81.939,-80.939,-78.939,-81.9
39,-84.939,-82.939,-82.939,-82.939,-79.939,-76.939,-75.939,-78.939,-81.939,-79.939,-78.939,-77.939,-79.939,-80.939,-79.939,-81.939,-84.939,-88.939,-88.939,-87.939,-86.939,-87.939,-89.939,-88.939,-87.939,-87.939,-86.939,-86.939,-84.939,-84.939,-85.939,-85.939,-86.939,-85.939,-85.939,-85.939,-85.939,-86.939,-87.939,-84.939,-81.939,-80.939,-81.939,-83.939,-83.939,-82.939,-82.939,-82.939,-82.939,-84.939,-85.939,-87.939,-85.939,-84.939,-84.939,-86.939,-87.939,-89.939,-87.939,-84.939,-85.939,-85.939,-83.939,-82.939,-81.939,-81.939,-79.939,-78.939,-80.939,-81.939,-81.939,-84.939,-88.939,-90.939,-91.939,-93.939,-94.939,-91.939,-86.939,-87.939,-86.939,-81.939,-81.939,-82.939,-82.939,-81.939,-79.939,-82.939,-84.939,-82.939,-82.939,-83.939,-88.939,-91.939,-90.939,-88.939,-87.939,-87.939,-89.939,-91.939,-90.939,-88.939,-86.939,-86.939,-86.939,-88.939,-87.939,-86.939,-89.939,-89.939,-88.939,-89.939,-88.939,-87.939,-86.939,-85.939,-85.939,-86.939,-87.939,-88.939,-88.939,-88.939,-88.939,-87.939,-87.939,-87.939,-89.939,-87.939,-85.939,-83.939,-86.939,-91.939,-90.939,-90.939,-91.939,-90.939,-90.939,-88.939,-89.939,-90.939,-87.939,-88.939,-91.939,-89.939,-87.939,-88.939,-91.939,-93.939,-94.939,-93.939,-90.939,-88.939,-87.939,-90.939,-92.939,-94.939,-92.939,-92.939,-93.939,-91.939,-91.939,-95.939,-95.939,-94.939,-91.939,-91.939,-96.939,-96.939,-93.939,-82.939,-77.939,-73.939,-68.939,-68.939,-73.939,-82.939,-90.939,-99.939,-91.939,-78.939,-79.939,-81.939,-83.939,-86.939,-86.939,-77.939,-72.939,-70.939,-75.939,-80.939,-82.939,-83.939,-84.939,-85.939,-86.939,-86.939,-79.939,-74.939,-70.939,-67.939,-64.939,-57.939003,-58.939003,-63.939003,-80.939,-85.939,-77.939,-79.939,-82.939,-83.939,-76.939,-68.939,-71.939,-76.939,-83.939,-79.939,-75.939,-75.939,-73.939,-70.939,-58.939003,-53.939003,-53.939003,-55.939003,-57.939003,-61.939003,-64.939,-66.939,-72.939,-72.939,-66.939,-64.939,-63.939003,-67.939,-67.939,-66.939,-66.939,-67.939,-69.939,-75.939,-79.939,70.061,75.061,79.061,39.060997,20.06099
7,21.060997,63.060997,92.061,76.061,71.061,69.061,66.061,62.060997,60.060997,55.060997,50.060997,47.060997,42.060997,35.060997,30.060997,12.060997,-18.939003,-32.939003,-38.939003,-28.939003,-18.939003,-10.939003,-11.939003,-8.939003,-3.939003,-0.939003,2.060997,5.060997,7.060997,7.060997,7.060997,7.060997,6.060997,8.060997,8.060997,4.060997,1.060997,-1.939003,-8.939003,-14.939003,-19.939003,-25.939003,-28.939003,-27.939003,-18.939003,-11.939003,-40.939003,-58.939003,-66.939,-73.939,-79.939,-84.939,-86.939,-89.939,-97.939,-100.939,-98.939,-96.939,-94.939,-90.939,-89.939,-88.939,-85.939,-84.939,-83.939,-81.939,-78.939,-73.939,-71.939,-68.939,-62.939003,-56.939003,-50.939003,-45.939003,-39.939003,-32.939003,-28.939003,-24.939003,-19.939003,-16.939003,-15.939003,-12.939003,-10.939003,-14.939003,-18.939003,-19.939003,-10.939003,-0.939003,10.060997,11.060997,9.060997,6.060997,5.060997,5.060997,1.060997,-3.939003,-8.939003,-11.939003,-16.939003,-20.939003,-27.939003,-33.939003,-43.939003,-51.939003,-56.939003,-64.939,-71.939,-78.939,-83.939,-88.939,-96.939,-99.939,-99.939,-97.939,-94.939,-91.939,-89.939,-88.939,-82.939,-80.939,-81.939,-78.939,-75.939,-73.939,-70.939,-67.939,-59.939003,-54.939003,-52.939003,-47.939003,-40.939003,-33.939003,-25.939003,-16.939003,-13.939003,-8.939003,-1.939003,3.060997,6.060997,4.060997,5.060997,7.060997,11.060997,15.060997,17.060997,9.060997,-0.939003,-8.939003,-4.939003,3.060997,11.060997,6.060997,-12.939003,-15.939003,-17.939003,-18.939003,-25.939003,-32.939003,-42.939003,-51.939003,-60.939003,-68.939,-73.939,-75.939,-82.939,-91.939,-95.939,-98.939,-98.939,-96.939,-94.939,-91.939,-89.939,-86.939,-84.939,-81.939,-79.939,-76.939,-74.939,-68.939,-66.939,-65.939,-61.939003,-54.939003,-43.939003,-43.939003,-42.939003,-35.939003,-25.939003,-16.939003,-15.939003,-10.939003,-1.939003,1.060997,4.060997,7.060997,9.060997,9.060997,11.060997,11.060997,8.060997,10.060997,10.060997,8.060997,4.060997,0.06099701,-3.939003,-6.939003,-6.939003,1.060997,12.
060997,29.060997,34.060997,36.060997,40.060997,45.060997,49.060997,52.060997,56.060997,61.060997,71.061,79.061,35.060997,17.060997,25.060997,21.060997,29.060997,62.060997,78.061,83.061,63.060997,53.060997,54.060997,52.060997,49.060997,42.060997,38.060997,35.060997,23.060997,15.060997,11.060997,13.060997,14.060997,7.060997,2.060997,-2.939003,-10.939003,-11.939003,-6.939003,-41.939003,-76.939,-96.939,-95.939,-84.939,-56.939003,-43.939003,-48.939003,-28.939003,-18.939003,-38.939003,-70.939,-103.939,-102.939,-100.939,-97.939,-57.939003,-25.939003,-21.939003,-29.939003,-40.939003,-49.939003,-56.939003,-62.939003,-70.939,-78.939,-85.939,-88.939,-91.939,-95.939,-95.939,-93.939,-89.939,-86.939,-84.939,-82.939,-79.939,-75.939,-74.939,-74.939,-70.939,-67.939,-63.939003,-56.939003,-50.939003,-69.939,-84.939,-97.939,-94.939,-80.939,-45.939003,-38.939003,-38.939003,-1.939003,1.060997,-28.939003,-71.939,-100.939,-92.939,-94.939,-94.939,-40.939003,-7.939003,6.060997,1.060997,0.06099701,11.060997,6.060997,-6.939003,-12.939003,-16.939003,-21.939003,-26.939003,-33.939003,-45.939003,-52.939003,-56.939003,-61.939003,-68.939,-79.939,-83.939,-87.939,-91.939,-92.939,-92.939,-90.939,-87.939,-83.939,-81.939,-78.939,-75.939,-72.939,-69.939,-67.939,-64.939,-60.939003,-49.939003,-40.939003,-35.939003,-37.939003,-41.939003,-30.939003,-27.939003,-28.939003,-21.939003,-12.939003,0.06099701,-2.939003,-13.939003,-24.939003,-38.939003,-57.939003,-69.939,-78.939,-81.939,-87.939,-91.939,-91.939,-92.939,-94.939,-95.939,-94.939,-92.939,-91.939,-91.939,-88.939,-85.939,-82.939,-81.939,-80.939,-79.939,-80.939,-81.939,-81.939,-80.939,-80.939,-79.939,-80.939,-81.939,-82.939,-81.939,-79.939,-79.939,-82.939,-81.939,-79.939,-77.939,-79.939,-84.939,-82.939,-79.939,-74.939,-76.939,-76.939,-76.939,-75.939,-75.939,-76.939,-77.939,-76.939,-77.939,-77.939,-79.939,-78.939,-76.939,-75.939,-76.939,-79.939,-75.939,-74.939,-78.939,-76.939,-71.939,-76.939,-77.939,-74.939,-77.939,-80.939,-82.939,-82.939,-81.939,-81.939,-81.
939,-81.939,-79.939,-78.939,-78.939,-81.939,-84.939,-82.939,-80.939,-80.939,-78.939,-77.939,-81.939,-78.939,-73.939,-77.939,-80.939,-81.939,-83.939,-85.939,-81.939,-78.939,-75.939,-77.939,-77.939,-77.939,-78.939,-79.939,-78.939,-79.939,-81.939,-79.939,-79.939,-81.939,-78.939,-76.939,-75.939,-76.939,-78.939,-78.939,-78.939,-79.939,-79.939,-79.939,-80.939,-80.939,-80.939,-79.939,-76.939,-71.939,-76.939,-79.939,-81.939,-83.939,-85.939,-86.939,-85.939,-82.939,-83.939,-83.939,-80.939,-78.939,-76.939,-79.939,-80.939,-77.939,-80.939,-82.939,-82.939,-82.939,-83.939,-84.939,-84.939,-82.939,-80.939,-78.939,-74.939,-75.939,-77.939,-81.939,-80.939,-75.939,-77.939,-80.939,-84.939,-84.939,-82.939,-80.939,-77.939,-74.939,-76.939,-77.939,-78.939,-78.939,-78.939,-79.939,-77.939,-74.939,-75.939,-76.939,-76.939,-76.939,-75.939,-74.939,-74.939,-75.939,-77.939,-78.939,-77.939,-80.939,-82.939,-84.939,-86.939,-88.939,-84.939,-82.939,-84.939,-84.939,-83.939,-83.939,-82.939,-80.939,-83.939,-85.939,-83.939,-82.939,-82.939,-79.939,-76.939,-73.939,-77.939,-80.939,-77.939,-76.939,-75.939,-76.939,-77.939,-77.939,-80.939,-83.939,-85.939,-85.939,-84.939,-85.939,-86.939,-86.939,-87.939,-87.939,-86.939,-86.939,-87.939,-86.939,-85.939,-84.939,-86.939,-89.939,-89.939,-87.939,-84.939,-84.939,-84.939,-85.939,-82.939,-80.939,-80.939,-82.939,-84.939,-84.939,-84.939,-84.939,-83.939,-83.939,-84.939,-86.939,-88.939,-86.939,-85.939,-83.939,-85.939,-86.939,-86.939,-85.939,-83.939,-84.939,-84.939,-82.939,-81.939,-81.939,-79.939,-78.939,-78.939,-79.939,-80.939,-80.939,-82.939,-84.939,-87.939,-89.939,-91.939,-92.939,-90.939,-87.939,-86.939,-85.939,-81.939,-80.939,-81.939,-84.939,-84.939,-81.939,-83.939,-84.939,-82.939,-82.939,-84.939,-87.939,-90.939,-90.939,-89.939,-89.939,-88.939,-89.939,-91.939,-89.939,-87.939,-87.939,-87.939,-87.939,-89.939,-87.939,-85.939,-91.939,-93.939,-90.939,-89.939,-88.939,-88.939,-87.939,-86.939,-84.939,-84.939,-85.939,-86.939,-87.939,-88.939,-88.939,-86.939,-85.939,-85.939,-87.939,-85.
939,-83.939,-84.939,-87.939,-90.939,-89.939,-89.939,-90.939,-90.939,-90.939,-90.939,-90.939,-91.939,-90.939,-90.939,-92.939,-91.939,-90.939,-90.939,-91.939,-92.939,-94.939,-93.939,-89.939,-87.939,-85.939,-89.939,-90.939,-91.939,-89.939,-89.939,-89.939,-86.939,-85.939,-89.939,-89.939,-87.939,-84.939,-85.939,-90.939,-91.939,-91.939,-90.939,-87.939,-81.939,-73.939,-70.939,-71.939,-78.939,-87.939,-95.939,-79.939,-56.939003,-64.939,-73.939,-80.939,-84.939,-85.939,-82.939,-81.939,-80.939,-85.939,-89.939,-94.939,-94.939,-94.939,-95.939,-96.939,-95.939,-80.939,-69.939,-63.939003,-62.939003,-60.939003,-50.939003,-55.939003,-66.939,-80.939,-83.939,-72.939,-76.939,-81.939,-81.939,-78.939,-74.939,-77.939,-80.939,-84.939,-79.939,-74.939,-72.939,-69.939,-66.939,-57.939003,-52.939003,-50.939003,-51.939003,-53.939003,-57.939003,-60.939003,-62.939003,-68.939,-69.939,-62.939003,-61.939003,-61.939003,-64.939,-66.939,-66.939,-69.939,-71.939,-71.939,-73.939,-75.939,73.061,77.061,80.061,45.060997,19.060997,1.060997,32.060997,55.060997,45.060997,33.060997,20.060997,11.060997,5.060997,3.060997,-3.939003,-10.939003,-15.939003,-20.939003,-23.939003,-28.939003,-26.939003,-15.939003,10.060997,29.060997,28.060997,31.060997,34.060997,30.060997,26.060997,21.060997,16.060997,9.060997,1.060997,-3.939003,-7.939003,-17.939003,-27.939003,-38.939003,-42.939003,-48.939003,-59.939003,-67.939,-74.939,-83.939,-89.939,-94.939,-96.939,-93.939,-84.939,-49.939003,-13.939003,-66.939,-95.939,-100.939,-100.939,-101.939,-101.939,-101.939,-102.939,-101.939,-96.939,-88.939,-82.939,-77.939,-69.939,-63.939003,-58.939003,-52.939003,-47.939003,-42.939003,-36.939003,-29.939003,-18.939003,-12.939003,-7.939003,-0.939003,8.060997,17.060997,22.060997,25.060997,24.060997,30.060997,37.060997,35.060997,32.060997,28.060997,6.060997,-9.939003,-8.939003,-1.939003,5.060997,-7.939003,-18.939003,-28.939003,-36.939003,-43.939003,-46.939003,-55.939003,-66.939,-73.939,-82.939,-91.939,-94.939,-95.939,-95.939,-96.939,-97.939,-98.939,-98.9
39,-99.939,-99.939,-100.939,-101.939,-101.939,-102.939,-99.939,-94.939,-90.939,-88.939,-81.939,-67.939,-61.939003,-57.939003,-48.939003,-43.939003,-38.939003,-29.939003,-20.939003,-16.939003,-10.939003,-5.939003,7.060997,14.060997,14.060997,18.060997,24.060997,30.060997,33.060997,36.060997,35.060997,36.060997,37.060997,36.060997,31.060997,19.060997,14.060997,12.060997,0.06099701,-11.939003,-25.939003,-23.939003,-7.939003,38.060997,53.060997,58.060997,81.061,40.060997,-66.939,-87.939,-95.939,-95.939,-96.939,-97.939,-97.939,-98.939,-99.939,-100.939,-100.939,-100.939,-101.939,-102.939,-98.939,-93.939,-88.939,-81.939,-76.939,-68.939,-60.939003,-51.939003,-45.939003,-38.939003,-31.939003,-25.939003,-18.939003,-3.939003,0.06099701,2.060997,13.060997,21.060997,25.060997,25.060997,26.060997,31.060997,37.060997,40.060997,36.060997,33.060997,33.060997,26.060997,17.060997,8.060997,3.060997,-1.939003,-14.939003,-25.939003,-34.939003,-42.939003,-51.939003,-64.939,-72.939,-75.939,-44.939003,-30.939003,-31.939003,-2.939003,35.060997,84.061,93.061,85.061,81.061,80.061,83.061,82.061,80.061,72.061,78.061,84.061,40.060997,14.060997,10.060997,10.060997,12.060997,17.060997,23.060997,28.060997,10.060997,-1.939003,-7.939003,-11.939003,-15.939003,-22.939003,-25.939003,-25.939003,-31.939003,-33.939003,-31.939003,-29.939003,-25.939003,-21.939003,-16.939003,-11.939003,-9.939003,-7.939003,-7.939003,-37.939003,-69.939,-92.939,-92.939,-82.939,-59.939003,-48.939003,-52.939003,-40.939003,-32.939003,-35.939003,-65.939,-103.939,-101.939,-101.939,-102.939,-99.939,-96.939,-96.939,-96.939,-97.939,-98.939,-99.939,-99.939,-100.939,-101.939,-101.939,-99.939,-96.939,-88.939,-80.939,-72.939,-62.939003,-53.939003,-47.939003,-39.939003,-31.939003,-20.939003,-16.939003,-15.939003,-5.939003,2.060997,9.060997,23.060997,33.060997,-26.939003,-70.939,-100.939,-96.939,-85.939,-66.939,-66.939,-71.939,-50.939003,-46.939003,-57.939003,-84.939,-101.939,-96.939,-93.939,-89.939,-71.939,-30.939003,33.060997,42.060997,50.06
0997,70.061,17.060997,-60.939003,-83.939,-95.939,-96.939,-96.939,-97.939,-98.939,-98.939,-99.939,-99.939,-98.939,-96.939,-97.939,-95.939,-88.939,-79.939,-69.939,-63.939003,-54.939003,-44.939003,-37.939003,-29.939003,-19.939003,-9.939003,-2.939003,2.060997,11.060997,24.060997,21.060997,17.060997,16.060997,12.060997,5.060997,-11.939003,-30.939003,-49.939003,-50.939003,-55.939003,-70.939,-84.939,-95.939,-93.939,-93.939,-92.939,-88.939,-85.939,-86.939,-87.939,-86.939,-83.939,-81.939,-81.939,-80.939,-79.939,-77.939,-78.939,-80.939,-80.939,-79.939,-77.939,-74.939,-73.939,-77.939,-78.939,-77.939,-76.939,-77.939,-82.939,-81.939,-81.939,-82.939,-84.939,-85.939,-83.939,-81.939,-80.939,-80.939,-80.939,-78.939,-78.939,-79.939,-80.939,-79.939,-74.939,-75.939,-77.939,-80.939,-78.939,-74.939,-75.939,-75.939,-73.939,-76.939,-79.939,-84.939,-82.939,-78.939,-74.939,-73.939,-74.939,-73.939,-73.939,-75.939,-74.939,-72.939,-75.939,-75.939,-71.939,-74.939,-78.939,-79.939,-79.939,-78.939,-79.939,-81.939,-82.939,-79.939,-78.939,-81.939,-85.939,-88.939,-85.939,-80.939,-75.939,-73.939,-74.939,-82.939,-80.939,-74.939,-76.939,-77.939,-76.939,-78.939,-80.939,-78.939,-79.939,-80.939,-82.939,-82.939,-79.939,-79.939,-78.939,-76.939,-79.939,-85.939,-84.939,-81.939,-76.939,-74.939,-73.939,-75.939,-75.939,-75.939,-75.939,-77.939,-82.939,-82.939,-81.939,-77.939,-77.939,-79.939,-77.939,-76.939,-76.939,-78.939,-78.939,-80.939,-83.939,-86.939,-83.939,-81.939,-81.939,-81.939,-80.939,-79.939,-81.939,-83.939,-78.939,-75.939,-72.939,-76.939,-80.939,-83.939,-85.939,-87.939,-84.939,-80.939,-76.939,-78.939,-78.939,-69.939,-68.939,-70.939,-74.939,-75.939,-74.939,-74.939,-75.939,-77.939,-76.939,-73.939,-75.939,-74.939,-73.939,-74.939,-74.939,-74.939,-74.939,-75.939,-78.939,-78.939,-77.939,-79.939,-79.939,-77.939,-75.939,-73.939,-74.939,-75.939,-76.939,-76.939,-76.939,-76.939,-77.939,-78.939,-82.939,-83.939,-82.939,-82.939,-83.939,-84.939,-86.939,-87.939,-84.939,-83.939,-85.939,-85.939,-84.939,-80.939,-79.939,-80.
939,-79.939,-77.939,-74.939,-78.939,-81.939,-76.939,-75.939,-75.939,-75.939,-76.939,-77.939,-80.939,-82.939,-81.939,-81.939,-82.939,-85.939,-86.939,-84.939,-86.939,-86.939,-85.939,-87.939,-90.939,-89.939,-86.939,-82.939,-85.939,-89.939,-90.939,-87.939,-82.939,-84.939,-85.939,-84.939,-82.939,-81.939,-83.939,-85.939,-87.939,-84.939,-83.939,-84.939,-84.939,-84.939,-83.939,-85.939,-88.939,-88.939,-86.939,-82.939,-84.939,-84.939,-81.939,-80.939,-81.939,-83.939,-82.939,-79.939,-80.939,-80.939,-78.939,-77.939,-78.939,-78.939,-79.939,-81.939,-80.939,-81.939,-85.939,-85.939,-85.939,-87.939,-87.939,-86.939,-83.939,-81.939,-79.939,-79.939,-79.939,-85.939,-87.939,-85.939,-84.939,-83.939,-82.939,-83.939,-85.939,-84.939,-85.939,-88.939,-89.939,-90.939,-91.939,-90.939,-88.939,-86.939,-86.939,-88.939,-88.939,-88.939,-90.939,-87.939,-84.939,-91.939,-94.939,-91.939,-88.939,-86.939,-87.939,-86.939,-85.939,-83.939,-83.939,-84.939,-85.939,-86.939,-89.939,-89.939,-87.939,-84.939,-83.939,-84.939,-82.939,-81.939,-86.939,-88.939,-89.939,-88.939,-88.939,-91.939,-90.939,-90.939,-90.939,-91.939,-93.939,-93.939,-92.939,-90.939,-92.939,-94.939,-96.939,-94.939,-89.939,-90.939,-90.939,-90.939,-89.939,-88.939,-90.939,-88.939,-86.939,-89.939,-90.939,-90.939,-87.939,-84.939,-85.939,-84.939,-82.939,-82.939,-84.939,-85.939,-88.939,-92.939,-99.939,-101.939,-100.939,-90.939,-83.939,-76.939,-75.939,-78.939,-88.939,-64.939,-33.939003,-51.939003,-67.939,-77.939,-85.939,-90.939,-91.939,-92.939,-93.939,-94.939,-96.939,-99.939,-98.939,-96.939,-96.939,-96.939,-95.939,-73.939,-59.939003,-55.939003,-61.939003,-65.939,-59.939003,-65.939,-75.939,-81.939,-81.939,-74.939,-75.939,-77.939,-78.939,-79.939,-81.939,-87.939,-87.939,-82.939,-78.939,-75.939,-74.939,-70.939,-65.939,-58.939003,-53.939003,-48.939003,-49.939003,-50.939003,-53.939003,-53.939003,-52.939003,-57.939003,-60.939003,-59.939003,-63.939003,-66.939,-67.939,-67.939,-67.939,-64.939,-64.939,-66.939,-68.939,-70.939,27.060997,25.060997,22.060997,9.060997,-0.93
9003,-7.939003,-1.939003,2.060997,-4.939003,-9.939003,-14.939003,-16.939003,-16.939003,-13.939003,-12.939003,-13.939003,-12.939003,-10.939003,-7.939003,-5.939003,-5.939003,-5.939003,-7.939003,-10.939003,-13.939003,-18.939003,-24.939003,-30.939003,-35.939003,-40.939003,-46.939003,-53.939003,-55.939003,-58.939003,-60.939003,-65.939,-70.939,-77.939,-78.939,-81.939,-86.939,-90.939,-93.939,-94.939,-95.939,-96.939,-95.939,-90.939,-80.939,-47.939003,-13.939003,-49.939003,-70.939,-75.939,-70.939,-65.939,-60.939003,-57.939003,-56.939003,-52.939003,-45.939003,-34.939003,-29.939003,-23.939003,-15.939003,-11.939003,-8.939003,-3.939003,1.060997,5.060997,4.060997,4.060997,8.060997,10.060997,11.060997,8.060997,7.060997,6.060997,7.060997,4.060997,-6.939003,-9.939003,-11.939003,-19.939003,-26.939003,-33.939003,7.060997,39.060997,36.060997,47.060997,58.060997,23.060997,-18.939003,-67.939,-76.939,-79.939,-80.939,-85.939,-90.939,-93.939,-96.939,-98.939,-95.939,-92.939,-89.939,-87.939,-84.939,-82.939,-79.939,-74.939,-68.939,-62.939003,-60.939003,-57.939003,-53.939003,-47.939003,-40.939003,-32.939003,-27.939003,-20.939003,-10.939003,-4.939003,-0.939003,2.060997,6.060997,8.060997,10.060997,12.060997,15.060997,15.060997,12.060997,13.060997,11.060997,5.060997,3.060997,0.06099701,-5.939003,-10.939003,-15.939003,-19.939003,-24.939003,-29.939003,-35.939003,-41.939003,-48.939003,-50.939003,-52.939003,-58.939003,-64.939,-71.939,-61.939003,-25.939003,70.061,91.061,82.061,100.061,54.060997,-55.939003,-87.939,-100.939,-90.939,-86.939,-83.939,-79.939,-76.939,-74.939,-68.939,-60.939003,-53.939003,-49.939003,-47.939003,-41.939003,-35.939003,-30.939003,-23.939003,-16.939003,-10.939003,-3.939003,2.060997,6.060997,9.060997,11.060997,8.060997,8.060997,15.060997,14.060997,11.060997,10.060997,8.060997,3.060997,-2.939003,-8.939003,-10.939003,-15.939003,-21.939003,-29.939003,-34.939003,-36.939003,-43.939003,-50.939003,-54.939003,-57.939003,-59.939003,-65.939,-71.939,-75.939,-78.939,-83.939,-89.939,-92.939,-90
.939,-50.939003,-29.939003,-27.939003,-5.939003,19.060997,50.060997,51.060997,41.060997,37.060997,33.060997,31.060997,28.060997,25.060997,20.060997,22.060997,24.060997,9.060997,-0.939003,-3.939003,-2.939003,-1.939003,-2.939003,-4.939003,-7.939003,-10.939003,-9.939003,-6.939003,-7.939003,-8.939003,-6.939003,-5.939003,-3.939003,-3.939003,-0.939003,6.060997,10.060997,15.060997,23.060997,24.060997,21.060997,5.060997,20.060997,64.061,14.060997,-40.939003,-87.939,-91.939,-78.939,-70.939,-61.939003,-52.939003,-45.939003,-43.939003,-50.939003,-75.939,-103.939,-102.939,-100.939,-100.939,-93.939,-86.939,-77.939,-73.939,-69.939,-62.939003,-56.939003,-50.939003,-48.939003,-47.939003,-42.939003,-35.939003,-30.939003,-24.939003,-17.939003,-6.939003,-2.939003,1.060997,5.060997,8.060997,11.060997,13.060997,13.060997,11.060997,9.060997,7.060997,8.060997,14.060997,16.060997,-37.939003,-76.939,-101.939,-96.939,-87.939,-75.939,-76.939,-78.939,-62.939003,-58.939003,-66.939,-87.939,-102.939,-98.939,-95.939,-91.939,-89.939,-43.939003,44.060997,54.060997,56.060997,59.060997,8.060997,-58.939003,-76.939,-82.939,-74.939,-69.939,-64.939,-60.939003,-56.939003,-51.939003,-47.939003,-40.939003,-33.939003,-32.939003,-28.939003,-19.939003,-13.939003,-10.939003,-5.939003,-0.939003,6.060997,4.060997,4.060997,11.060997,14.060997,14.060997,5.060997,-0.939003,-5.939003,-13.939003,-22.939003,-31.939003,-39.939003,-47.939003,-56.939003,-66.939,-76.939,-78.939,-80.939,-83.939,-86.939,-90.939,-89.939,-88.939,-86.939,-85.939,-84.939,-82.939,-82.939,-83.939,-81.939,-79.939,-77.939,-77.939,-77.939,-76.939,-77.939,-78.939,-78.939,-77.939,-78.939,-75.939,-74.939,-77.939,-77.939,-77.939,-75.939,-76.939,-81.939,-80.939,-79.939,-80.939,-81.939,-80.939,-80.939,-80.939,-81.939,-78.939,-77.939,-78.939,-79.939,-79.939,-78.939,-77.939,-74.939,-76.939,-78.939,-81.939,-79.939,-75.939,-76.939,-77.939,-77.939,-77.939,-77.939,-78.939,-77.939,-75.939,-75.939,-75.939,-74.939,-76.939,-77.939,-75.939,-74.939,-74.939,-76.939,-75.
939,-72.939,-75.939,-77.939,-78.939,-79.939,-78.939,-78.939,-79.939,-80.939,-80.939,-79.939,-79.939,-80.939,-82.939,-79.939,-76.939,-73.939,-74.939,-75.939,-79.939,-77.939,-73.939,-75.939,-78.939,-78.939,-79.939,-81.939,-81.939,-82.939,-82.939,-77.939,-77.939,-80.939,-80.939,-79.939,-79.939,-80.939,-82.939,-81.939,-80.939,-79.939,-78.939,-78.939,-76.939,-77.939,-79.939,-76.939,-77.939,-82.939,-83.939,-82.939,-78.939,-79.939,-81.939,-79.939,-78.939,-78.939,-78.939,-79.939,-81.939,-84.939,-85.939,-82.939,-80.939,-80.939,-82.939,-82.939,-79.939,-81.939,-84.939,-77.939,-74.939,-76.939,-78.939,-79.939,-81.939,-84.939,-86.939,-83.939,-79.939,-76.939,-78.939,-79.939,-75.939,-74.939,-76.939,-77.939,-79.939,-80.939,-79.939,-78.939,-76.939,-76.939,-75.939,-77.939,-77.939,-74.939,-77.939,-80.939,-79.939,-77.939,-75.939,-78.939,-79.939,-79.939,-80.939,-78.939,-74.939,-72.939,-71.939,-72.939,-73.939,-73.939,-74.939,-75.939,-74.939,-74.939,-73.939,-77.939,-78.939,-77.939,-78.939,-79.939,-81.939,-83.939,-85.939,-82.939,-81.939,-85.939,-86.939,-85.939,-82.939,-82.939,-83.939,-84.939,-82.939,-78.939,-81.939,-83.939,-78.939,-77.939,-76.939,-77.939,-78.939,-79.939,-79.939,-79.939,-81.939,-80.939,-79.939,-84.939,-85.939,-82.939,-84.939,-85.939,-86.939,-86.939,-85.939,-85.939,-85.939,-85.939,-86.939,-87.939,-87.939,-86.939,-85.939,-85.939,-85.939,-86.939,-83.939,-82.939,-83.939,-84.939,-85.939,-85.939,-84.939,-83.939,-84.939,-83.939,-82.939,-82.939,-83.939,-86.939,-86.939,-83.939,-83.939,-83.939,-80.939,-79.939,-79.939,-82.939,-83.939,-81.939,-82.939,-83.939,-81.939,-80.939,-81.939,-80.939,-80.939,-80.939,-81.939,-82.939,-85.939,-84.939,-82.939,-85.939,-87.939,-88.939,-84.939,-81.939,-81.939,-83.939,-86.939,-86.939,-85.939,-84.939,-82.939,-82.939,-83.939,-84.939,-84.939,-83.939,-83.939,-83.939,-86.939,-88.939,-88.939,-89.939,-89.939,-88.939,-87.939,-86.939,-88.939,-89.939,-88.939,-86.939,-86.939,-90.939,-91.939,-88.939,-87.939,-87.939,-86.939,-83.939,-80.939,-81.939,-82.939,-85.939,-86.
939,-86.939,-87.939,-86.939,-85.939,-85.939,-85.939,-86.939,-85.939,-85.939,-88.939,-89.939,-89.939,-88.939,-88.939,-89.939,-89.939,-89.939,-89.939,-90.939,-90.939,-92.939,-92.939,-91.939,-90.939,-90.939,-94.939,-93.939,-92.939,-89.939,-88.939,-89.939,-87.939,-87.939,-89.939,-89.939,-87.939,-89.939,-90.939,-91.939,-89.939,-87.939,-87.939,-85.939,-84.939,-86.939,-87.939,-86.939,-90.939,-93.939,-97.939,-99.939,-100.939,-98.939,-92.939,-85.939,-83.939,-84.939,-91.939,-82.939,-67.939,-78.939,-87.939,-91.939,-93.939,-94.939,-94.939,-94.939,-95.939,-94.939,-96.939,-99.939,-96.939,-95.939,-98.939,-99.939,-98.939,-77.939,-64.939,-58.939003,-60.939003,-62.939003,-66.939,-74.939,-84.939,-83.939,-79.939,-69.939,-68.939,-71.939,-80.939,-82.939,-82.939,-83.939,-84.939,-84.939,-80.939,-76.939,-71.939,-65.939,-59.939003,-53.939003,-48.939003,-44.939003,-47.939003,-49.939003,-51.939003,-50.939003,-50.939003,-54.939003,-54.939003,-49.939003,-57.939003,-64.939,-64.939,-63.939003,-62.939003,-64.939,-65.939,-65.939,-65.939,-64.939,-11.939003,-14.939003,-19.939003,-16.939003,-13.939003,-11.939003,-18.939003,-24.939003,-26.939003,-24.939003,-20.939003,-15.939003,-11.939003,-5.939003,0.06099701,6.060997,10.060997,16.060997,23.060997,31.060997,27.060997,10.060997,-27.939003,-57.939003,-60.939003,-69.939,-79.939,-85.939,-89.939,-91.939,-95.939,-98.939,-96.939,-95.939,-94.939,-94.939,-94.939,-95.939,-92.939,-91.939,-90.939,-89.939,-86.939,-82.939,-78.939,-76.939,-71.939,-64.939,-56.939003,-35.939003,-12.939003,-25.939003,-34.939003,-37.939003,-29.939003,-22.939003,-15.939003,-12.939003,-10.939003,-7.939003,-2.939003,6.060997,8.060997,10.060997,15.060997,14.060997,13.060997,15.060997,17.060997,19.060997,11.060997,5.060997,1.060997,0.06099701,-1.939003,-10.939003,-18.939003,-25.939003,-27.939003,-33.939003,-47.939003,-56.939003,-63.939003,-73.939,-81.939,-88.939,3.060997,74.061,70.061,80.061,90.061,44.060997,-16.939003,-90.939,-95.939,-93.939,-91.939,-89.939,-87.939,-85.939,-81.939,-77.939,-70
.939,-63.939003,-59.939003,-54.939003,-50.939003,-47.939003,-42.939003,-35.939003,-26.939003,-19.939003,-15.939003,-11.939003,-6.939003,-2.939003,4.060997,10.060997,15.060997,19.060997,21.060997,22.060997,24.060997,22.060997,22.060997,21.060997,16.060997,11.060997,10.060997,5.060997,-1.939003,-9.939003,-16.939003,-25.939003,-31.939003,-38.939003,-49.939003,-59.939003,-68.939,-73.939,-79.939,-88.939,-94.939,-98.939,-98.939,-98.939,-98.939,-96.939,-94.939,-93.939,-80.939,-36.939003,75.061,94.061,77.061,87.061,49.060997,-36.939003,-64.939,-74.939,-59.939003,-52.939003,-47.939003,-42.939003,-37.939003,-34.939003,-24.939003,-14.939003,-5.939003,-1.939003,0.06099701,4.060997,7.060997,10.060997,14.060997,18.060997,18.060997,20.060997,20.060997,21.060997,20.060997,16.060997,7.060997,0.06099701,1.060997,-2.939003,-9.939003,-18.939003,-26.939003,-34.939003,-44.939003,-53.939003,-59.939003,-69.939,-79.939,-87.939,-91.939,-91.939,-95.939,-98.939,-96.939,-94.939,-92.939,-91.939,-89.939,-87.939,-86.939,-85.939,-83.939,-80.939,-75.939,-39.939003,-19.939003,-16.939003,-6.939003,2.060997,10.060997,5.060997,-3.939003,-5.939003,-8.939003,-11.939003,-14.939003,-16.939003,-16.939003,-17.939003,-17.939003,-11.939003,-8.939003,-9.939003,-7.939003,-6.939003,-4.939003,-7.939003,-11.939003,-4.939003,3.060997,12.060997,15.060997,19.060997,27.060997,29.060997,31.060997,36.060997,41.060997,47.060997,52.060997,56.060997,63.060997,59.060997,50.060997,16.060997,36.060997,109.061,49.060997,-20.939003,-84.939,-92.939,-78.939,-78.939,-70.939,-56.939003,-51.939003,-53.939003,-63.939003,-83.939,-103.939,-102.939,-101.939,-98.939,-73.939,-51.939003,-38.939003,-32.939003,-27.939003,-18.939003,-9.939003,-1.939003,-0.939003,0.06099701,5.060997,10.060997,14.060997,14.060997,18.060997,24.060997,23.060997,22.060997,21.060997,19.060997,16.060997,10.060997,7.060997,3.060997,-5.939003,-14.939003,-18.939003,-18.939003,-21.939003,-59.939003,-86.939,-101.939,-97.939,-90.939,-81.939,-81.939,-82.939,-68.939,-65.939,-
73.939,-90.939,-102.939,-100.939,-96.939,-92.939,-90.939,-46.939003,37.060997,43.060997,39.060997,30.060997,-0.939003,-38.939003,-46.939003,-45.939003,-34.939003,-27.939003,-20.939003,-15.939003,-10.939003,-4.939003,0.06099701,6.060997,14.060997,14.060997,16.060997,23.060997,22.060997,18.060997,18.060997,20.060997,21.060997,13.060997,7.060997,8.060997,6.060997,1.060997,-15.939003,-31.939003,-47.939003,-55.939003,-63.939003,-74.939,-83.939,-90.939,-90.939,-90.939,-91.939,-93.939,-92.939,-85.939,-81.939,-80.939,-80.939,-79.939,-78.939,-81.939,-82.939,-78.939,-79.939,-81.939,-81.939,-79.939,-75.939,-77.939,-78.939,-77.939,-78.939,-78.939,-77.939,-77.939,-79.939,-77.939,-76.939,-77.939,-77.939,-78.939,-74.939,-75.939,-79.939,-78.939,-78.939,-78.939,-77.939,-77.939,-77.939,-78.939,-81.939,-77.939,-75.939,-77.939,-78.939,-79.939,-77.939,-76.939,-75.939,-78.939,-79.939,-81.939,-79.939,-77.939,-77.939,-79.939,-80.939,-77.939,-75.939,-74.939,-73.939,-73.939,-75.939,-75.939,-72.939,-77.939,-79.939,-75.939,-74.939,-74.939,-75.939,-75.939,-74.939,-75.939,-76.939,-77.939,-78.939,-77.939,-78.939,-78.939,-78.939,-79.939,-79.939,-78.939,-77.939,-77.939,-75.939,-73.939,-72.939,-74.939,-75.939,-76.939,-74.939,-72.939,-75.939,-78.939,-78.939,-79.939,-80.939,-82.939,-83.939,-81.939,-74.939,-74.939,-79.939,-79.939,-79.939,-80.939,-81.939,-80.939,-79.939,-79.939,-81.939,-81.939,-81.939,-77.939,-78.939,-81.939,-77.939,-77.939,-81.939,-82.939,-80.939,-77.939,-80.939,-83.939,-80.939,-79.939,-78.939,-78.939,-79.939,-81.939,-82.939,-83.939,-81.939,-80.939,-80.939,-82.939,-83.939,-79.939,-81.939,-84.939,-77.939,-75.939,-80.939,-79.939,-79.939,-80.939,-83.939,-85.939,-82.939,-79.939,-76.939,-78.939,-80.939,-80.939,-80.939,-81.939,-80.939,-82.939,-85.939,-83.939,-81.939,-77.939,-77.939,-79.939,-81.939,-79.939,-76.939,-79.939,-82.939,-81.939,-78.939,-74.939,-77.939,-78.939,-79.939,-78.939,-76.939,-70.939,-68.939,-67.939,-68.939,-70.939,-70.939,-72.939,-73.939,-71.939,-71.939,-69.939,-72.939,-73.9
39,-72.939,-73.939,-74.939,-75.939,-78.939,-81.939,-78.939,-79.939,-84.939,-85.939,-86.939,-85.939,-85.939,-86.939,-88.939,-86.939,-81.939,-83.939,-84.939,-79.939,-78.939,-78.939,-79.939,-79.939,-80.939,-78.939,-77.939,-81.939,-79.939,-76.939,-83.939,-84.939,-80.939,-81.939,-83.939,-86.939,-85.939,-82.939,-82.939,-84.939,-88.939,-87.939,-86.939,-85.939,-86.939,-87.939,-85.939,-85.939,-88.939,-84.939,-82.939,-82.939,-83.939,-84.939,-85.939,-83.939,-82.939,-84.939,-83.939,-82.939,-81.939,-81.939,-84.939,-84.939,-81.939,-82.939,-82.939,-80.939,-79.939,-78.939,-83.939,-84.939,-83.939,-85.939,-86.939,-84.939,-83.939,-83.939,-82.939,-82.939,-81.939,-81.939,-82.939,-84.939,-83.939,-81.939,-84.939,-87.939,-90.939,-85.939,-82.939,-82.939,-85.939,-89.939,-85.939,-83.939,-83.939,-81.939,-81.939,-83.939,-84.939,-84.939,-83.939,-83.939,-81.939,-84.939,-86.939,-86.939,-88.939,-90.939,-89.939,-87.939,-84.939,-87.939,-89.939,-88.939,-87.939,-87.939,-90.939,-89.939,-85.939,-87.939,-87.939,-85.939,-80.939,-77.939,-80.939,-83.939,-85.939,-85.939,-85.939,-84.939,-84.939,-83.939,-85.939,-87.939,-88.939,-87.939,-87.939,-88.939,-89.939,-88.939,-88.939,-88.939,-87.939,-88.939,-89.939,-88.939,-88.939,-88.939,-90.939,-91.939,-90.939,-88.939,-87.939,-90.939,-92.939,-92.939,-89.939,-87.939,-88.939,-87.939,-87.939,-89.939,-89.939,-88.939,-89.939,-89.939,-90.939,-89.939,-88.939,-88.939,-87.939,-86.939,-89.939,-89.939,-87.939,-91.939,-94.939,-92.939,-94.939,-97.939,-100.939,-97.939,-92.939,-90.939,-90.939,-94.939,-95.939,-94.939,-99.939,-101.939,-100.939,-99.939,-97.939,-97.939,-96.939,-96.939,-95.939,-96.939,-98.939,-95.939,-94.939,-99.939,-101.939,-99.939,-81.939,-68.939,-61.939003,-57.939003,-57.939003,-69.939,-80.939,-88.939,-82.939,-74.939,-65.939,-63.939003,-67.939,-80.939,-84.939,-83.939,-81.939,-81.939,-84.939,-80.939,-75.939,-67.939,-59.939003,-53.939003,-49.939003,-46.939003,-43.939003,-46.939003,-48.939003,-49.939003,-50.939003,-51.939003,-53.939003,-49.939003,-40.939003,-51.939003,-60
.939003,-60.939003,-59.939003,-58.939003,-63.939003,-64.939,-63.939003,-62.939003,-61.939003,-19.939003,-13.939003,-4.939003,-4.939003,-5.939003,-8.939003,4.060997,16.060997,23.060997,30.060997,36.060997,42.060997,46.060997,46.060997,54.060997,61.060997,62.060997,66.061,70.061,79.061,69.061,38.060997,-36.939003,-92.939,-93.939,-95.939,-98.939,-99.939,-97.939,-92.939,-88.939,-85.939,-85.939,-80.939,-74.939,-72.939,-69.939,-67.939,-61.939003,-53.939003,-48.939003,-40.939003,-34.939003,-30.939003,-26.939003,-21.939003,-12.939003,-6.939003,-5.939003,-7.939003,-8.939003,2.060997,10.060997,13.060997,18.060997,22.060997,23.060997,22.060997,20.060997,13.060997,9.060997,8.060997,-0.939003,-9.939003,-14.939003,-24.939003,-35.939003,-41.939003,-47.939003,-52.939003,-59.939003,-68.939,-76.939,-79.939,-80.939,-81.939,-83.939,-85.939,-86.939,-87.939,-89.939,-91.939,-94.939,-94.939,-96.939,-100.939,-8.939003,65.061,68.061,69.061,65.061,33.060997,-13.939003,-75.939,-70.939,-59.939003,-51.939003,-43.939003,-34.939003,-24.939003,-16.939003,-9.939003,-1.939003,4.060997,9.060997,11.060997,13.060997,16.060997,18.060997,20.060997,20.060997,19.060997,22.060997,20.060997,18.060997,16.060997,12.060997,9.060997,4.060997,-3.939003,-17.939003,-26.939003,-32.939003,-35.939003,-41.939003,-50.939003,-58.939003,-66.939,-74.939,-78.939,-80.939,-81.939,-83.939,-85.939,-86.939,-88.939,-91.939,-93.939,-95.939,-95.939,-96.939,-98.939,-97.939,-95.939,-88.939,-84.939,-80.939,-71.939,-65.939,-60.939003,-49.939003,-27.939003,20.060997,25.060997,12.060997,10.060997,2.060997,-11.939003,-2.939003,6.060997,11.060997,16.060997,21.060997,20.060997,21.060997,26.060997,28.060997,28.060997,23.060997,18.060997,12.060997,8.060997,2.060997,-4.939003,-7.939003,-14.939003,-31.939003,-41.939003,-47.939003,-50.939003,-56.939003,-64.939,-73.939,-79.939,-79.939,-80.939,-81.939,-83.939,-85.939,-87.939,-89.939,-91.939,-93.939,-95.939,-98.939,-97.939,-93.939,-86.939,-84.939,-79.939,-72.939,-63.939003,-55.939003,-51.939003,-44.
939003,-36.939003,-31.939003,-25.939003,-16.939003,-9.939003,-3.939003,-1.939003,1.060997,3.060997,0.06099701,-6.939003,-21.939003,-26.939003,-25.939003,-22.939003,-18.939003,-15.939003,-11.939003,-7.939003,-3.939003,-1.939003,1.060997,3.060997,6.060997,9.060997,7.060997,12.060997,28.060997,44.060997,57.060997,54.060997,54.060997,59.060997,66.061,74.061,79.061,81.061,80.061,81.061,80.061,78.061,78.061,78.061,72.061,68.061,58.060997,14.060997,20.060997,74.061,30.060997,-25.939003,-85.939,-97.939,-90.939,-78.939,-72.939,-71.939,-62.939003,-59.939003,-68.939,-85.939,-103.939,-103.939,-102.939,-101.939,-35.939003,18.060997,23.060997,23.060997,20.060997,22.060997,23.060997,23.060997,18.060997,12.060997,4.060997,-2.939003,-10.939003,-19.939003,-26.939003,-33.939003,-39.939003,-46.939003,-53.939003,-61.939003,-68.939,-74.939,-77.939,-78.939,-80.939,-82.939,-83.939,-83.939,-84.939,-93.939,-99.939,-103.939,-99.939,-94.939,-83.939,-81.939,-82.939,-73.939,-73.939,-81.939,-94.939,-103.939,-101.939,-97.939,-90.939,-56.939003,-28.939003,-4.939003,-11.939003,-18.939003,-20.939003,-5.939003,12.060997,18.060997,22.060997,23.060997,23.060997,21.060997,19.060997,18.060997,19.060997,11.060997,7.060997,5.060997,-0.939003,-7.939003,-15.939003,-24.939003,-33.939003,-40.939003,-44.939003,-47.939003,-54.939003,-61.939003,-65.939,-67.939,-68.939,-76.939,-83.939,-87.939,-84.939,-82.939,-82.939,-83.939,-85.939,-81.939,-78.939,-76.939,-78.939,-79.939,-76.939,-77.939,-79.939,-78.939,-77.939,-74.939,-77.939,-79.939,-78.939,-78.939,-80.939,-80.939,-80.939,-78.939,-79.939,-81.939,-82.939,-80.939,-77.939,-77.939,-77.939,-78.939,-76.939,-75.939,-76.939,-78.939,-79.939,-75.939,-74.939,-76.939,-77.939,-76.939,-75.939,-76.939,-76.939,-76.939,-77.939,-79.939,-76.939,-74.939,-76.939,-77.939,-78.939,-78.939,-79.939,-79.939,-79.939,-80.939,-79.939,-79.939,-78.939,-79.939,-79.939,-78.939,-77.939,-76.939,-75.939,-75.939,-74.939,-74.939,-70.939,-66.939,-72.939,-78.939,-76.939,-74.939,-73.939,-74.939,-76.939,-7
9.939,-77.939,-73.939,-72.939,-73.939,-76.939,-76.939,-77.939,-77.939,-76.939,-76.939,-77.939,-77.939,-78.939,-76.939,-75.939,-72.939,-73.939,-74.939,-73.939,-73.939,-71.939,-74.939,-76.939,-76.939,-76.939,-77.939,-78.939,-78.939,-77.939,-79.939,-79.939,-78.939,-77.939,-77.939,-79.939,-80.939,-82.939,-79.939,-79.939,-82.939,-80.939,-77.939,-75.939,-76.939,-78.939,-78.939,-79.939,-80.939,-77.939,-75.939,-76.939,-79.939,-84.939,-80.939,-78.939,-77.939,-78.939,-78.939,-78.939,-77.939,-77.939,-78.939,-79.939,-81.939,-81.939,-81.939,-79.939,-80.939,-82.939,-78.939,-78.939,-81.939,-81.939,-80.939,-79.939,-81.939,-83.939,-82.939,-81.939,-78.939,-80.939,-82.939,-80.939,-81.939,-82.939,-81.939,-81.939,-83.939,-82.939,-82.939,-80.939,-82.939,-85.939,-83.939,-81.939,-78.939,-78.939,-77.939,-76.939,-74.939,-73.939,-74.939,-73.939,-70.939,-69.939,-67.939,-65.939,-63.939003,-62.939003,-63.939003,-65.939,-68.939,-68.939,-68.939,-66.939,-66.939,-65.939,-67.939,-68.939,-67.939,-68.939,-68.939,-67.939,-68.939,-70.939,-73.939,-78.939,-82.939,-84.939,-86.939,-86.939,-87.939,-88.939,-89.939,-87.939,-83.939,-83.939,-82.939,-80.939,-80.939,-80.939,-80.939,-81.939,-80.939,-79.939,-79.939,-81.939,-80.939,-77.939,-81.939,-82.939,-77.939,-79.939,-82.939,-85.939,-85.939,-84.939,-84.939,-85.939,-87.939,-86.939,-86.939,-86.939,-86.939,-86.939,-83.939,-84.939,-89.939,-86.939,-81.939,-79.939,-81.939,-84.939,-83.939,-82.939,-81.939,-83.939,-84.939,-83.939,-83.939,-85.939,-82.939,-79.939,-74.939,-79.939,-82.939,-83.939,-82.939,-81.939,-84.939,-86.939,-85.939,-86.939,-87.939,-86.939,-84.939,-81.939,-83.939,-83.939,-84.939,-83.939,-80.939,-79.939,-80.939,-82.939,-84.939,-86.939,-88.939,-87.939,-85.939,-83.939,-82.939,-82.939,-80.939,-82.939,-85.939,-83.939,-81.939,-82.939,-82.939,-83.939,-85.939,-85.939,-86.939,-87.939,-88.939,-87.939,-88.939,-91.939,-90.939,-87.939,-81.939,-83.939,-87.939,-90.939,-90.939,-89.939,-88.939,-87.939,-87.939,-86.939,-85.939,-83.939,-82.939,-82.939,-83.939,-84.939,-85.939,-
83.939,-82.939,-83.939,-83.939,-82.939,-83.939,-85.939,-89.939,-86.939,-83.939,-83.939,-84.939,-87.939,-89.939,-90.939,-89.939,-88.939,-87.939,-87.939,-87.939,-86.939,-86.939,-87.939,-89.939,-89.939,-88.939,-88.939,-87.939,-87.939,-88.939,-88.939,-86.939,-88.939,-91.939,-92.939,-91.939,-89.939,-88.939,-88.939,-87.939,-86.939,-86.939,-87.939,-87.939,-87.939,-87.939,-88.939,-89.939,-92.939,-92.939,-87.939,-88.939,-91.939,-91.939,-92.939,-94.939,-93.939,-92.939,-93.939,-94.939,-93.939,-94.939,-94.939,-93.939,-96.939,-99.939,-100.939,-99.939,-97.939,-96.939,-97.939,-97.939,-95.939,-96.939,-101.939,-101.939,-98.939,-80.939,-67.939,-59.939003,-49.939003,-49.939003,-67.939,-77.939,-82.939,-71.939,-65.939,-62.939003,-65.939,-69.939,-72.939,-80.939,-88.939,-84.939,-80.939,-76.939,-75.939,-73.939,-63.939003,-54.939003,-48.939003,-48.939003,-48.939003,-49.939003,-49.939003,-49.939003,-51.939003,-52.939003,-55.939003,-53.939003,-48.939003,-37.939003,-46.939003,-55.939003,-55.939003,-55.939003,-55.939003,-58.939003,-59.939003,-58.939003,-61.939003,-63.939003,19.060997,24.060997,30.060997,30.060997,18.060997,-6.939003,22.060997,51.060997,66.061,69.061,68.061,73.061,75.061,75.061,78.061,81.061,77.061,78.061,78.061,82.061,68.061,33.060997,-33.939003,-83.939,-83.939,-80.939,-76.939,-74.939,-69.939,-60.939003,-50.939003,-44.939003,-45.939003,-39.939003,-32.939003,-28.939003,-24.939003,-20.939003,-14.939003,-8.939003,-5.939003,-0.939003,4.060997,4.060997,5.060997,7.060997,8.060997,7.060997,0.06099701,-3.939003,-6.939003,1.060997,0.06099701,-6.939003,-7.939003,-9.939003,-14.939003,-19.939003,-25.939003,-31.939003,-35.939003,-37.939003,-45.939003,-52.939003,-56.939003,-64.939,-73.939,-78.939,-82.939,-85.939,-90.939,-95.939,-98.939,-96.939,-93.939,-90.939,-88.939,-87.939,-84.939,-80.939,-74.939,-72.939,-71.939,-68.939,-64.939,-60.939003,-11.939003,26.060997,25.060997,22.060997,16.060997,7.060997,-6.939003,-25.939003,-19.939003,-10.939003,-3.939003,2.060997,7.060997,12.060997,15.060997,16
.060997,17.060997,17.060997,15.060997,11.060997,7.060997,7.060997,3.060997,-1.939003,-7.939003,-13.939003,-16.939003,-21.939003,-26.939003,-31.939003,-36.939003,-40.939003,-44.939003,-50.939003,-61.939003,-68.939,-73.939,-75.939,-79.939,-85.939,-90.939,-94.939,-101.939,-100.939,-96.939,-93.939,-90.939,-86.939,-84.939,-82.939,-78.939,-73.939,-69.939,-65.939,-62.939003,-59.939003,-54.939003,-48.939003,-39.939003,-34.939003,-29.939003,-20.939003,-14.939003,-10.939003,-3.939003,1.060997,0.06099701,-2.939003,-6.939003,-8.939003,-6.939003,0.06099701,7.060997,10.060997,8.060997,8.060997,8.060997,0.06099701,-5.939003,-7.939003,-11.939003,-16.939003,-25.939003,-31.939003,-35.939003,-39.939003,-44.939003,-50.939003,-53.939003,-59.939003,-71.939,-78.939,-83.939,-86.939,-89.939,-94.939,-95.939,-95.939,-94.939,-92.939,-89.939,-85.939,-82.939,-79.939,-75.939,-71.939,-68.939,-64.939,-59.939003,-54.939003,-48.939003,-39.939003,-34.939003,-27.939003,-19.939003,-12.939003,-7.939003,-4.939003,-0.939003,4.060997,8.060997,11.060997,12.060997,13.060997,12.060997,5.060997,1.060997,-0.939003,-0.939003,-0.939003,-3.939003,-1.939003,2.060997,8.060997,15.060997,21.060997,27.060997,31.060997,34.060997,43.060997,53.060997,31.060997,19.060997,15.060997,24.060997,35.060997,46.060997,64.061,82.061,69.061,62.060997,59.060997,64.061,68.061,66.061,63.060997,60.060997,57.060997,52.060997,45.060997,44.060997,42.060997,33.060997,28.060997,23.060997,0.06099701,1.060997,24.060997,0.06099701,-35.939003,-87.939,-99.939,-94.939,-80.939,-74.939,-75.939,-69.939,-68.939,-73.939,-87.939,-101.939,-100.939,-101.939,-102.939,-48.939003,-5.939003,-1.939003,-7.939003,-16.939003,-19.939003,-23.939003,-27.939003,-31.939003,-35.939003,-42.939003,-48.939003,-55.939003,-61.939003,-68.939,-74.939,-78.939,-82.939,-87.939,-91.939,-94.939,-95.939,-93.939,-89.939,-84.939,-79.939,-79.939,-71.939,-61.939003,-71.939,-84.939,-101.939,-99.939,-93.939,-83.939,-81.939,-82.939,-74.939,-74.939,-82.939,-95.939,-103.939,-101.939,-97.939,
-88.939,-34.939003,-6.939003,-3.939003,-11.939003,-15.939003,-7.939003,-0.939003,4.060997,-0.939003,-3.939003,-5.939003,-13.939003,-21.939003,-27.939003,-29.939003,-30.939003,-36.939003,-41.939003,-43.939003,-48.939003,-53.939003,-60.939003,-66.939,-73.939,-77.939,-79.939,-79.939,-80.939,-82.939,-87.939,-87.939,-85.939,-86.939,-88.939,-89.939,-86.939,-82.939,-80.939,-79.939,-78.939,-75.939,-74.939,-75.939,-77.939,-77.939,-76.939,-77.939,-79.939,-78.939,-76.939,-75.939,-77.939,-78.939,-78.939,-78.939,-78.939,-79.939,-79.939,-79.939,-79.939,-79.939,-79.939,-80.939,-79.939,-78.939,-77.939,-77.939,-76.939,-76.939,-77.939,-78.939,-78.939,-74.939,-73.939,-75.939,-76.939,-76.939,-74.939,-75.939,-76.939,-73.939,-73.939,-77.939,-76.939,-76.939,-75.939,-76.939,-78.939,-78.939,-79.939,-79.939,-78.939,-77.939,-75.939,-76.939,-78.939,-79.939,-79.939,-78.939,-79.939,-78.939,-76.939,-75.939,-74.939,-74.939,-72.939,-67.939,-73.939,-77.939,-78.939,-76.939,-74.939,-73.939,-74.939,-77.939,-76.939,-73.939,-71.939,-72.939,-74.939,-76.939,-77.939,-76.939,-76.939,-77.939,-78.939,-77.939,-76.939,-75.939,-74.939,-72.939,-73.939,-73.939,-73.939,-74.939,-74.939,-76.939,-76.939,-74.939,-75.939,-76.939,-77.939,-77.939,-76.939,-80.939,-80.939,-76.939,-75.939,-76.939,-80.939,-80.939,-80.939,-78.939,-78.939,-80.939,-78.939,-76.939,-77.939,-77.939,-77.939,-77.939,-77.939,-77.939,-77.939,-76.939,-75.939,-77.939,-82.939,-80.939,-79.939,-77.939,-79.939,-79.939,-77.939,-76.939,-75.939,-79.939,-81.939,-81.939,-80.939,-80.939,-79.939,-80.939,-81.939,-80.939,-80.939,-82.939,-81.939,-81.939,-80.939,-81.939,-83.939,-82.939,-81.939,-79.939,-81.939,-82.939,-80.939,-81.939,-82.939,-81.939,-80.939,-81.939,-81.939,-81.939,-79.939,-81.939,-83.939,-80.939,-77.939,-77.939,-75.939,-73.939,-71.939,-69.939,-68.939,-69.939,-67.939,-64.939,-63.939003,-62.939003,-62.939003,-59.939003,-57.939003,-59.939003,-62.939003,-66.939,-64.939,-63.939003,-64.939,-65.939,-64.939,-64.939,-64.939,-62.939003,-65.939,-66.939,-65.939,-66.
939,-68.939,-72.939,-76.939,-78.939,-80.939,-82.939,-85.939,-86.939,-86.939,-87.939,-87.939,-84.939,-84.939,-84.939,-83.939,-81.939,-79.939,-80.939,-80.939,-79.939,-79.939,-79.939,-81.939,-81.939,-79.939,-83.939,-82.939,-78.939,-79.939,-81.939,-83.939,-83.939,-82.939,-83.939,-84.939,-85.939,-84.939,-85.939,-85.939,-86.939,-86.939,-84.939,-86.939,-90.939,-87.939,-83.939,-80.939,-81.939,-84.939,-83.939,-81.939,-82.939,-84.939,-84.939,-81.939,-83.939,-86.939,-81.939,-77.939,-72.939,-78.939,-82.939,-83.939,-83.939,-83.939,-84.939,-85.939,-84.939,-85.939,-86.939,-85.939,-82.939,-80.939,-83.939,-84.939,-83.939,-81.939,-79.939,-78.939,-78.939,-80.939,-81.939,-83.939,-84.939,-84.939,-85.939,-83.939,-81.939,-79.939,-79.939,-81.939,-85.939,-83.939,-81.939,-82.939,-82.939,-83.939,-85.939,-86.939,-86.939,-86.939,-85.939,-83.939,-84.939,-87.939,-88.939,-88.939,-83.939,-84.939,-85.939,-88.939,-88.939,-87.939,-86.939,-85.939,-86.939,-85.939,-84.939,-83.939,-82.939,-82.939,-83.939,-83.939,-83.939,-81.939,-80.939,-83.939,-83.939,-82.939,-82.939,-84.939,-87.939,-85.939,-83.939,-83.939,-84.939,-86.939,-88.939,-88.939,-86.939,-86.939,-86.939,-86.939,-87.939,-89.939,-86.939,-86.939,-86.939,-88.939,-88.939,-86.939,-85.939,-85.939,-88.939,-89.939,-86.939,-89.939,-91.939,-92.939,-90.939,-87.939,-87.939,-87.939,-86.939,-85.939,-85.939,-87.939,-87.939,-87.939,-87.939,-88.939,-89.939,-91.939,-91.939,-87.939,-88.939,-90.939,-88.939,-89.939,-94.939,-94.939,-93.939,-92.939,-93.939,-94.939,-93.939,-92.939,-91.939,-95.939,-98.939,-98.939,-96.939,-94.939,-93.939,-94.939,-95.939,-95.939,-97.939,-101.939,-102.939,-99.939,-84.939,-70.939,-58.939003,-47.939003,-46.939003,-69.939,-79.939,-82.939,-71.939,-65.939,-62.939003,-66.939,-70.939,-70.939,-73.939,-78.939,-79.939,-80.939,-81.939,-75.939,-70.939,-63.939003,-58.939003,-54.939003,-54.939003,-53.939003,-51.939003,-53.939003,-55.939003,-52.939003,-53.939003,-56.939003,-49.939003,-45.939003,-40.939003,-47.939003,-53.939003,-51.939003,-51.939003,-53.9390
03,-55.939003,-56.939003,-56.939003,-61.939003,-65.939,74.061,72.061,69.061,72.061,47.060997,-4.939003,36.060997,77.061,98.061,96.061,85.061,85.061,86.061,87.061,83.061,79.061,70.061,67.061,65.061,61.060997,44.060997,13.060997,-26.939003,-53.939003,-51.939003,-45.939003,-37.939003,-32.939003,-25.939003,-14.939003,-2.939003,4.060997,1.060997,6.060997,13.060997,17.060997,21.060997,25.060997,28.060997,29.060997,26.060997,25.060997,27.060997,22.060997,18.060997,16.060997,7.060997,-4.939003,-17.939003,-13.939003,-6.939003,-14.939003,-30.939003,-56.939003,-62.939003,-68.939,-78.939,-87.939,-94.939,-97.939,-97.939,-94.939,-96.939,-97.939,-95.939,-97.939,-99.939,-98.939,-97.939,-96.939,-96.939,-95.939,-89.939,-82.939,-74.939,-67.939,-63.939003,-61.939003,-54.939003,-44.939003,-34.939003,-28.939003,-25.939003,-21.939003,-14.939003,-3.939003,-9.939003,-17.939003,-25.939003,-30.939003,-32.939003,-20.939003,1.060997,29.060997,30.060997,32.060997,36.060997,34.060997,32.060997,28.060997,22.060997,15.060997,9.060997,1.060997,-7.939003,-19.939003,-29.939003,-34.939003,-42.939003,-55.939003,-64.939,-73.939,-81.939,-88.939,-94.939,-99.939,-101.939,-101.939,-100.939,-99.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-96.939,-95.939,-96.939,-89.939,-79.939,-74.939,-66.939,-58.939003,-54.939003,-49.939003,-39.939003,-30.939003,-20.939003,-14.939003,-9.939003,-4.939003,2.060997,9.060997,17.060997,21.060997,25.060997,32.060997,35.060997,35.060997,36.060997,28.060997,-0.939003,-4.939003,1.060997,3.060997,5.060997,7.060997,-6.939003,-20.939003,-26.939003,-33.939003,-39.939003,-53.939003,-65.939,-76.939,-83.939,-92.939,-100.939,-102.939,-102.939,-101.939,-101.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-95.939,-94.939,-86.939,-78.939,-75.939,-70.939,-63.939003,-57.939003,-49.939003,-42.939003,-34.939003,-27.939003,-20.939003,-11.939003,-0.939003,5.060997,11.060997,17.060997,23.060997,29.060997,34.060997,35.060997,33.060997,33.060997,32.060997,27.060997,29.060997,26.060
997,15.060997,7.060997,-1.939003,-4.939003,-9.939003,-15.939003,-4.939003,11.060997,33.060997,44.060997,49.060997,55.060997,62.060997,68.061,73.061,76.061,74.061,91.061,107.061,59.060997,28.060997,14.060997,37.060997,54.060997,51.060997,65.061,81.061,60.060997,46.060997,38.060997,39.060997,36.060997,25.060997,18.060997,14.060997,7.060997,-0.939003,-9.939003,-9.939003,-11.939003,-18.939003,-22.939003,-22.939003,-15.939003,-15.939003,-22.939003,-25.939003,-41.939003,-89.939,-100.939,-96.939,-84.939,-77.939,-75.939,-76.939,-77.939,-78.939,-88.939,-99.939,-97.939,-98.939,-103.939,-81.939,-65.939,-61.939003,-70.939,-85.939,-90.939,-95.939,-100.939,-101.939,-100.939,-100.939,-99.939,-98.939,-98.939,-98.939,-97.939,-96.939,-95.939,-95.939,-91.939,-85.939,-80.939,-73.939,-65.939,-53.939003,-43.939003,-42.939003,-25.939003,-4.939003,-26.939003,-57.939003,-99.939,-97.939,-90.939,-82.939,-82.939,-83.939,-72.939,-72.939,-81.939,-94.939,-103.939,-101.939,-97.939,-87.939,-23.939003,10.060997,15.060997,14.060997,17.060997,31.060997,8.060997,-27.939003,-51.939003,-64.939,-68.939,-82.939,-93.939,-100.939,-101.939,-99.939,-99.939,-99.939,-100.939,-101.939,-101.939,-101.939,-101.939,-101.939,-99.939,-96.939,-92.939,-86.939,-83.939,-87.939,-84.939,-80.939,-76.939,-75.939,-77.939,-76.939,-75.939,-75.939,-74.939,-72.939,-71.939,-73.939,-79.939,-80.939,-80.939,-78.939,-78.939,-78.939,-77.939,-76.939,-77.939,-78.939,-78.939,-78.939,-77.939,-76.939,-77.939,-78.939,-80.939,-78.939,-76.939,-75.939,-78.939,-81.939,-79.939,-77.939,-75.939,-77.939,-78.939,-79.939,-77.939,-75.939,-72.939,-72.939,-74.939,-76.939,-77.939,-75.939,-74.939,-75.939,-70.939,-70.939,-74.939,-76.939,-78.939,-76.939,-77.939,-78.939,-77.939,-77.939,-77.939,-75.939,-73.939,-71.939,-73.939,-76.939,-78.939,-79.939,-79.939,-81.939,-81.939,-77.939,-74.939,-72.939,-75.939,-75.939,-71.939,-75.939,-78.939,-79.939,-79.939,-77.939,-73.939,-72.939,-73.939,-74.939,-74.939,-72.939,-72.939,-73.939,-75.939,-76.939,-77.939,-78.939,-79.939,
-79.939,-77.939,-73.939,-73.939,-73.939,-73.939,-74.939,-74.939,-75.939,-76.939,-77.939,-78.939,-77.939,-73.939,-75.939,-77.939,-76.939,-76.939,-75.939,-80.939,-80.939,-75.939,-74.939,-75.939,-82.939,-81.939,-77.939,-76.939,-77.939,-78.939,-76.939,-75.939,-80.939,-80.939,-77.939,-76.939,-75.939,-74.939,-78.939,-79.939,-74.939,-75.939,-80.939,-81.939,-81.939,-78.939,-80.939,-81.939,-78.939,-75.939,-74.939,-81.939,-82.939,-80.939,-79.939,-79.939,-80.939,-80.939,-79.939,-81.939,-82.939,-81.939,-82.939,-83.939,-81.939,-81.939,-82.939,-83.939,-82.939,-81.939,-82.939,-82.939,-81.939,-82.939,-83.939,-81.939,-79.939,-78.939,-79.939,-79.939,-77.939,-77.939,-76.939,-73.939,-72.939,-73.939,-72.939,-70.939,-67.939,-65.939,-64.939,-65.939,-63.939003,-59.939003,-59.939003,-59.939003,-59.939003,-56.939003,-53.939003,-56.939003,-59.939003,-64.939,-61.939003,-59.939003,-62.939003,-64.939,-65.939,-63.939003,-61.939003,-59.939003,-63.939003,-66.939,-66.939,-67.939,-69.939,-72.939,-74.939,-75.939,-76.939,-77.939,-82.939,-84.939,-83.939,-85.939,-86.939,-86.939,-86.939,-86.939,-86.939,-82.939,-77.939,-78.939,-78.939,-77.939,-79.939,-80.939,-82.939,-82.939,-82.939,-85.939,-84.939,-81.939,-80.939,-80.939,-81.939,-81.939,-80.939,-81.939,-82.939,-82.939,-82.939,-83.939,-84.939,-85.939,-85.939,-86.939,-88.939,-90.939,-88.939,-85.939,-82.939,-83.939,-85.939,-82.939,-82.939,-84.939,-85.939,-84.939,-79.939,-82.939,-86.939,-81.939,-76.939,-73.939,-78.939,-82.939,-82.939,-83.939,-84.939,-83.939,-83.939,-82.939,-83.939,-84.939,-82.939,-80.939,-78.939,-82.939,-83.939,-80.939,-79.939,-78.939,-77.939,-77.939,-77.939,-78.939,-79.939,-78.939,-81.939,-84.939,-84.939,-81.939,-78.939,-79.939,-81.939,-86.939,-84.939,-82.939,-81.939,-81.939,-82.939,-84.939,-85.939,-84.939,-83.939,-81.939,-78.939,-78.939,-81.939,-86.939,-89.939,-88.939,-86.939,-84.939,-85.939,-85.939,-84.939,-83.939,-83.939,-85.939,-84.939,-83.939,-82.939,-82.939,-81.939,-81.939,-81.939,-81.939,-79.939,-79.939,-84.939,-84.939,-82.939,-81.939,
-82.939,-85.939,-84.939,-84.939,-86.939,-86.939,-87.939,-86.939,-85.939,-83.939,-84.939,-85.939,-87.939,-89.939,-92.939,-87.939,-85.939,-84.939,-86.939,-87.939,-84.939,-83.939,-84.939,-89.939,-90.939,-87.939,-88.939,-89.939,-90.939,-88.939,-85.939,-86.939,-86.939,-87.939,-85.939,-84.939,-86.939,-86.939,-86.939,-86.939,-87.939,-90.939,-90.939,-90.939,-88.939,-90.939,-91.939,-86.939,-87.939,-93.939,-93.939,-93.939,-91.939,-93.939,-96.939,-93.939,-91.939,-90.939,-94.939,-97.939,-95.939,-92.939,-90.939,-89.939,-91.939,-94.939,-96.939,-98.939,-102.939,-103.939,-102.939,-90.939,-75.939,-58.939003,-48.939003,-48.939003,-73.939,-82.939,-84.939,-75.939,-68.939,-64.939,-68.939,-71.939,-69.939,-67.939,-64.939,-72.939,-80.939,-88.939,-77.939,-67.939,-66.939,-65.939,-64.939,-63.939003,-58.939003,-51.939003,-57.939003,-62.939003,-53.939003,-54.939003,-56.939003,-45.939003,-42.939003,-46.939003,-50.939003,-52.939003,-47.939003,-48.939003,-51.939003,-52.939003,-53.939003,-54.939003,-60.939003,-65.939,78.061,77.061,74.061,71.061,45.060997,-3.939003,21.060997,50.060997,66.061,63.060997,51.060997,43.060997,37.060997,33.060997,24.060997,18.060997,11.060997,8.060997,5.060997,-1.939003,-8.939003,-15.939003,-13.939003,-7.939003,3.060997,9.060997,11.060997,12.060997,13.060997,14.060997,14.060997,12.060997,4.060997,2.060997,2.060997,-3.939003,-10.939003,-18.939003,-24.939003,-30.939003,-38.939003,-42.939003,-44.939003,-50.939003,-56.939003,-62.939003,-66.939,-70.939,-74.939,-48.939003,-16.939003,-29.939003,-52.939003,-85.939,-89.939,-90.939,-89.939,-88.939,-89.939,-86.939,-82.939,-76.939,-75.939,-72.939,-66.939,-63.939003,-60.939003,-53.939003,-48.939003,-45.939003,-39.939003,-33.939003,-26.939003,-21.939003,-16.939003,-8.939003,-1.939003,2.060997,7.060997,12.060997,14.060997,18.060997,21.060997,19.060997,17.060997,17.060997,7.060997,0.06099701,3.060997,2.060997,1.060997,2.060997,-3.939003,-19.939003,-26.939003,-30.939003,-34.939003,-40.939003,-47.939003,-54.939003,-60.939003,-63.939003,-65
.939,-68.939,-71.939,-75.939,-78.939,-80.939,-83.939,-87.939,-89.939,-92.939,-94.939,-93.939,-89.939,-89.939,-86.939,-83.939,-79.939,-74.939,-71.939,-64.939,-56.939003,-51.939003,-45.939003,-40.939003,-36.939003,-32.939003,-28.939003,-20.939003,-11.939003,-5.939003,0.06099701,6.060997,9.060997,13.060997,20.060997,23.060997,24.060997,21.060997,18.060997,15.060997,12.060997,9.060997,4.060997,-1.939003,-5.939003,-9.939003,-15.939003,-23.939003,-35.939003,-27.939003,26.060997,50.060997,62.060997,77.061,61.060997,14.060997,-37.939003,-75.939,-77.939,-80.939,-82.939,-86.939,-90.939,-94.939,-96.939,-97.939,-97.939,-93.939,-89.939,-82.939,-76.939,-71.939,-64.939,-58.939003,-52.939003,-47.939003,-42.939003,-36.939003,-30.939003,-24.939003,-15.939003,-7.939003,-3.939003,0.06099701,3.060997,6.060997,10.060997,14.060997,14.060997,14.060997,16.060997,17.060997,16.060997,11.060997,7.060997,4.060997,-0.939003,-5.939003,-11.939003,-19.939003,-30.939003,-36.939003,-42.939003,-51.939003,-55.939003,-60.939003,-63.939003,-66.939,-66.939,-43.939003,-29.939003,-23.939003,-7.939003,16.060997,52.060997,65.061,67.061,62.060997,57.060997,52.060997,54.060997,54.060997,49.060997,50.060997,51.060997,28.060997,9.060997,-6.939003,2.060997,9.060997,6.060997,4.060997,4.060997,0.06099701,-1.939003,-2.939003,-3.939003,-4.939003,-5.939003,-6.939003,-8.939003,-10.939003,-10.939003,-10.939003,-7.939003,-4.939003,-3.939003,-1.939003,2.060997,1.060997,4.060997,10.060997,17.060997,0.06099701,-75.939,-99.939,-99.939,-91.939,-85.939,-80.939,-83.939,-84.939,-83.939,-91.939,-102.939,-100.939,-101.939,-102.939,-96.939,-90.939,-89.939,-91.939,-95.939,-92.939,-88.939,-83.939,-79.939,-73.939,-66.939,-58.939003,-52.939003,-50.939003,-45.939003,-38.939003,-28.939003,-20.939003,-17.939003,-12.939003,-6.939003,-2.939003,1.060997,5.060997,11.060997,15.060997,12.060997,21.060997,31.060997,5.060997,-37.939003,-97.939,-96.939,-90.939,-87.939,-87.939,-88.939,-77.939,-77.939,-85.939,-95.939,-102.939,-100.939,-100.939,-98.93
9,-76.939,-32.939003,34.060997,55.060997,67.061,70.061,18.060997,-51.939003,-77.939,-89.939,-88.939,-88.939,-87.939,-88.939,-76.939,-60.939003,-56.939003,-59.939003,-71.939,-78.939,-81.939,-82.939,-82.939,-82.939,-81.939,-81.939,-83.939,-81.939,-80.939,-82.939,-81.939,-79.939,-79.939,-79.939,-79.939,-78.939,-76.939,-74.939,-73.939,-73.939,-72.939,-74.939,-80.939,-79.939,-77.939,-75.939,-76.939,-77.939,-78.939,-79.939,-77.939,-78.939,-79.939,-77.939,-76.939,-75.939,-77.939,-78.939,-79.939,-77.939,-75.939,-74.939,-76.939,-78.939,-80.939,-80.939,-78.939,-78.939,-77.939,-77.939,-77.939,-77.939,-76.939,-75.939,-73.939,-75.939,-76.939,-74.939,-74.939,-75.939,-73.939,-73.939,-77.939,-78.939,-78.939,-75.939,-75.939,-76.939,-76.939,-76.939,-77.939,-76.939,-75.939,-75.939,-76.939,-78.939,-79.939,-78.939,-78.939,-78.939,-78.939,-76.939,-72.939,-68.939,-72.939,-74.939,-74.939,-74.939,-75.939,-74.939,-75.939,-76.939,-72.939,-72.939,-75.939,-76.939,-77.939,-76.939,-75.939,-73.939,-76.939,-78.939,-80.939,-79.939,-78.939,-74.939,-74.939,-77.939,-75.939,-74.939,-72.939,-72.939,-71.939,-72.939,-74.939,-76.939,-77.939,-77.939,-74.939,-76.939,-78.939,-76.939,-75.939,-74.939,-77.939,-77.939,-72.939,-73.939,-75.939,-81.939,-79.939,-74.939,-75.939,-76.939,-78.939,-75.939,-74.939,-79.939,-79.939,-78.939,-79.939,-78.939,-76.939,-77.939,-78.939,-78.939,-78.939,-79.939,-80.939,-79.939,-76.939,-79.939,-82.939,-80.939,-76.939,-73.939,-77.939,-77.939,-76.939,-76.939,-78.939,-79.939,-79.939,-77.939,-79.939,-79.939,-78.939,-79.939,-81.939,-80.939,-80.939,-79.939,-80.939,-81.939,-80.939,-81.939,-82.939,-82.939,-85.939,-87.939,-85.939,-82.939,-79.939,-81.939,-81.939,-77.939,-74.939,-70.939,-68.939,-67.939,-67.939,-69.939,-70.939,-66.939,-65.939,-65.939,-66.939,-64.939,-59.939003,-59.939003,-60.939003,-62.939003,-60.939003,-57.939003,-56.939003,-58.939003,-63.939003,-62.939003,-62.939003,-63.939003,-63.939003,-62.939003,-60.939003,-61.939003,-62.939003,-66.939,-69.939,-67.939,-67.939,-68.939,-73.939,
-75.939,-73.939,-74.939,-76.939,-79.939,-82.939,-83.939,-86.939,-87.939,-87.939,-85.939,-83.939,-81.939,-80.939,-79.939,-79.939,-78.939,-76.939,-78.939,-81.939,-80.939,-80.939,-81.939,-82.939,-83.939,-81.939,-83.939,-85.939,-84.939,-83.939,-81.939,-84.939,-85.939,-85.939,-83.939,-83.939,-84.939,-84.939,-84.939,-86.939,-87.939,-87.939,-85.939,-84.939,-85.939,-85.939,-86.939,-82.939,-81.939,-83.939,-84.939,-83.939,-80.939,-81.939,-83.939,-82.939,-81.939,-80.939,-81.939,-81.939,-79.939,-81.939,-84.939,-83.939,-84.939,-85.939,-88.939,-89.939,-87.939,-83.939,-79.939,-82.939,-82.939,-79.939,-78.939,-77.939,-76.939,-75.939,-76.939,-77.939,-79.939,-80.939,-82.939,-85.939,-87.939,-85.939,-82.939,-82.939,-82.939,-84.939,-83.939,-83.939,-82.939,-82.939,-83.939,-83.939,-83.939,-84.939,-85.939,-85.939,-83.939,-83.939,-84.939,-84.939,-86.939,-89.939,-87.939,-84.939,-82.939,-82.939,-83.939,-83.939,-84.939,-86.939,-85.939,-84.939,-83.939,-81.939,-80.939,-79.939,-79.939,-81.939,-80.939,-80.939,-83.939,-83.939,-81.939,-81.939,-82.939,-85.939,-85.939,-85.939,-87.939,-86.939,-86.939,-85.939,-85.939,-85.939,-86.939,-86.939,-85.939,-87.939,-91.939,-89.939,-87.939,-83.939,-85.939,-86.939,-86.939,-85.939,-84.939,-89.939,-91.939,-88.939,-89.939,-89.939,-89.939,-88.939,-86.939,-88.939,-90.939,-91.939,-88.939,-87.939,-87.939,-87.939,-88.939,-88.939,-89.939,-91.939,-89.939,-89.939,-91.939,-93.939,-93.939,-88.939,-88.939,-90.939,-89.939,-89.939,-89.939,-93.939,-97.939,-96.939,-94.939,-95.939,-97.939,-98.939,-94.939,-93.939,-93.939,-92.939,-92.939,-91.939,-95.939,-99.939,-101.939,-102.939,-102.939,-93.939,-80.939,-61.939003,-56.939003,-59.939003,-79.939,-85.939,-84.939,-74.939,-69.939,-68.939,-66.939,-66.939,-71.939,-72.939,-71.939,-67.939,-68.939,-71.939,-72.939,-72.939,-73.939,-68.939,-62.939003,-60.939003,-59.939003,-61.939003,-62.939003,-62.939003,-58.939003,-56.939003,-53.939003,-42.939003,-43.939003,-55.939003,-55.939003,-53.939003,-47.939003,-47.939003,-50.939003,-49.939003,-49.939003,-51
.939003,-55.939003,-59.939003,58.060997,53.060997,46.060997,43.060997,24.060997,-8.939003,3.060997,17.060997,22.060997,18.060997,10.060997,4.060997,0.06099701,-3.939003,-8.939003,-11.939003,-13.939003,-15.939003,-18.939003,-21.939003,-20.939003,-16.939003,-6.939003,1.060997,5.060997,6.060997,5.060997,2.060997,-0.939003,-7.939003,-11.939003,-16.939003,-26.939003,-31.939003,-34.939003,-42.939003,-50.939003,-60.939003,-67.939,-73.939,-80.939,-83.939,-85.939,-88.939,-92.939,-96.939,-94.939,-92.939,-93.939,-61.939003,-22.939003,-32.939003,-50.939003,-76.939,-78.939,-75.939,-68.939,-63.939003,-58.939003,-53.939003,-47.939003,-39.939003,-36.939003,-32.939003,-25.939003,-19.939003,-13.939003,-8.939003,-5.939003,-3.939003,2.060997,7.060997,11.060997,12.060997,13.060997,15.060997,17.060997,19.060997,18.060997,16.060997,11.060997,11.060997,11.060997,6.060997,-1.939003,-12.939003,-4.939003,10.060997,38.060997,42.060997,39.060997,41.060997,9.060997,-55.939003,-68.939,-75.939,-80.939,-85.939,-91.939,-97.939,-98.939,-97.939,-95.939,-94.939,-92.939,-90.939,-88.939,-87.939,-84.939,-80.939,-77.939,-74.939,-71.939,-64.939,-56.939003,-53.939003,-47.939003,-41.939003,-35.939003,-29.939003,-26.939003,-19.939003,-12.939003,-7.939003,-1.939003,4.060997,5.060997,6.060997,11.060997,13.060997,16.060997,17.060997,17.060997,15.060997,14.060997,13.060997,15.060997,12.060997,6.060997,-2.939003,-9.939003,-14.939003,-19.939003,-24.939003,-33.939003,-40.939003,-46.939003,-53.939003,-60.939003,-69.939,-81.939,-61.939003,34.060997,74.061,90.061,102.061,76.061,12.060997,-50.939003,-92.939,-85.939,-82.939,-80.939,-79.939,-76.939,-74.939,-70.939,-67.939,-61.939003,-55.939003,-47.939003,-39.939003,-32.939003,-27.939003,-19.939003,-13.939003,-7.939003,-3.939003,0.06099701,6.060997,10.060997,14.060997,17.060997,18.060997,19.060997,17.060997,14.060997,12.060997,10.060997,9.060997,2.060997,-3.939003,-4.939003,-9.939003,-15.939003,-22.939003,-28.939003,-34.939003,-40.939003,-47.939003,-56.939003,-64.939,-75.93
9,-81.939,-87.939,-94.939,-97.939,-98.939,-95.939,-93.939,-89.939,-56.939003,-34.939003,-20.939003,-7.939003,10.060997,38.060997,47.060997,46.060997,38.060997,29.060997,20.060997,21.060997,21.060997,17.060997,13.060997,7.060997,3.060997,-2.939003,-10.939003,-10.939003,-9.939003,-10.939003,-16.939003,-22.939003,-18.939003,-13.939003,-8.939003,-6.939003,-4.939003,2.060997,3.060997,3.060997,6.060997,10.060997,15.060997,21.060997,26.060997,29.060997,34.060997,39.060997,20.060997,22.060997,43.060997,59.060997,42.060997,-48.939003,-87.939,-101.939,-95.939,-90.939,-85.939,-88.939,-90.939,-87.939,-94.939,-103.939,-102.939,-99.939,-92.939,-78.939,-67.939,-67.939,-65.939,-61.939003,-54.939003,-46.939003,-37.939003,-33.939003,-28.939003,-19.939003,-12.939003,-8.939003,-7.939003,-3.939003,3.060997,10.060997,15.060997,16.060997,17.060997,17.060997,17.060997,15.060997,12.060997,10.060997,9.060997,5.060997,5.060997,4.060997,-15.939003,-49.939003,-98.939,-97.939,-92.939,-91.939,-90.939,-90.939,-84.939,-84.939,-90.939,-97.939,-101.939,-100.939,-99.939,-98.939,-93.939,-52.939003,24.060997,49.060997,60.060997,55.060997,9.060997,-48.939003,-62.939003,-66.939,-62.939003,-56.939003,-48.939003,-42.939003,-36.939003,-31.939003,-37.939003,-47.939003,-62.939003,-69.939,-74.939,-76.939,-75.939,-73.939,-74.939,-76.939,-80.939,-79.939,-78.939,-79.939,-79.939,-79.939,-79.939,-79.939,-78.939,-76.939,-74.939,-71.939,-72.939,-74.939,-74.939,-74.939,-77.939,-75.939,-73.939,-73.939,-74.939,-75.939,-77.939,-79.939,-77.939,-78.939,-79.939,-77.939,-76.939,-76.939,-77.939,-77.939,-78.939,-77.939,-76.939,-76.939,-76.939,-76.939,-78.939,-78.939,-77.939,-76.939,-76.939,-77.939,-78.939,-78.939,-78.939,-76.939,-73.939,-74.939,-74.939,-72.939,-73.939,-73.939,-74.939,-75.939,-78.939,-77.939,-76.939,-73.939,-73.939,-75.939,-77.939,-77.939,-77.939,-76.939,-76.939,-77.939,-77.939,-77.939,-78.939,-77.939,-77.939,-75.939,-76.939,-77.939,-73.939,-69.939,-73.939,-75.939,-76.939,-75.939,-74.939,-72.939,-73.939,-76.939,
-74.939,-73.939,-75.939,-76.939,-76.939,-77.939,-76.939,-74.939,-76.939,-79.939,-81.939,-80.939,-78.939,-73.939,-74.939,-77.939,-75.939,-74.939,-73.939,-73.939,-72.939,-73.939,-75.939,-76.939,-77.939,-76.939,-74.939,-76.939,-78.939,-77.939,-75.939,-74.939,-77.939,-76.939,-71.939,-73.939,-75.939,-78.939,-77.939,-74.939,-75.939,-75.939,-78.939,-76.939,-74.939,-78.939,-79.939,-78.939,-79.939,-79.939,-76.939,-76.939,-77.939,-78.939,-78.939,-76.939,-78.939,-78.939,-75.939,-78.939,-80.939,-79.939,-76.939,-73.939,-75.939,-75.939,-75.939,-76.939,-78.939,-80.939,-79.939,-77.939,-79.939,-79.939,-77.939,-79.939,-80.939,-81.939,-80.939,-78.939,-80.939,-81.939,-81.939,-82.939,-82.939,-83.939,-86.939,-89.939,-88.939,-85.939,-82.939,-83.939,-82.939,-77.939,-74.939,-70.939,-70.939,-68.939,-66.939,-69.939,-70.939,-66.939,-65.939,-64.939,-65.939,-64.939,-61.939003,-61.939003,-61.939003,-63.939003,-62.939003,-60.939003,-58.939003,-59.939003,-63.939003,-62.939003,-63.939003,-64.939,-64.939,-63.939003,-63.939003,-65.939,-66.939,-69.939,-71.939,-69.939,-69.939,-70.939,-72.939,-74.939,-73.939,-75.939,-76.939,-77.939,-80.939,-82.939,-85.939,-87.939,-85.939,-83.939,-81.939,-79.939,-80.939,-82.939,-80.939,-79.939,-76.939,-79.939,-81.939,-79.939,-79.939,-79.939,-80.939,-81.939,-81.939,-84.939,-86.939,-84.939,-83.939,-81.939,-84.939,-85.939,-84.939,-83.939,-82.939,-83.939,-82.939,-81.939,-84.939,-85.939,-86.939,-84.939,-83.939,-85.939,-86.939,-86.939,-83.939,-81.939,-82.939,-82.939,-83.939,-83.939,-82.939,-82.939,-83.939,-83.939,-84.939,-82.939,-79.939,-78.939,-80.939,-83.939,-83.939,-84.939,-86.939,-88.939,-89.939,-88.939,-84.939,-79.939,-83.939,-83.939,-79.939,-78.939,-77.939,-77.939,-75.939,-74.939,-75.939,-78.939,-81.939,-82.939,-83.939,-86.939,-85.939,-83.939,-82.939,-82.939,-82.939,-82.939,-82.939,-82.939,-81.939,-81.939,-79.939,-79.939,-81.939,-84.939,-86.939,-85.939,-86.939,-85.939,-83.939,-83.939,-85.939,-85.939,-85.939,-83.939,-82.939,-82.939,-83.939,-85.939,-86.939,-85.939,-84.939,-
84.939,-82.939,-80.939,-80.939,-80.939,-82.939,-80.939,-80.939,-82.939,-81.939,-80.939,-81.939,-82.939,-84.939,-85.939,-85.939,-86.939,-85.939,-84.939,-84.939,-85.939,-87.939,-87.939,-86.939,-84.939,-85.939,-87.939,-88.939,-87.939,-84.939,-86.939,-87.939,-87.939,-85.939,-83.939,-88.939,-89.939,-88.939,-88.939,-89.939,-88.939,-86.939,-84.939,-87.939,-90.939,-91.939,-88.939,-86.939,-86.939,-86.939,-87.939,-88.939,-89.939,-90.939,-89.939,-88.939,-91.939,-93.939,-92.939,-89.939,-89.939,-89.939,-88.939,-87.939,-89.939,-92.939,-96.939,-95.939,-94.939,-95.939,-97.939,-97.939,-93.939,-93.939,-95.939,-94.939,-93.939,-90.939,-95.939,-99.939,-100.939,-102.939,-101.939,-95.939,-83.939,-65.939,-62.939003,-67.939,-82.939,-85.939,-82.939,-72.939,-69.939,-70.939,-66.939,-65.939,-70.939,-73.939,-74.939,-66.939,-61.939003,-58.939003,-62.939003,-66.939,-66.939,-64.939,-61.939003,-60.939003,-62.939003,-65.939,-63.939003,-62.939003,-62.939003,-57.939003,-50.939003,-46.939003,-48.939003,-58.939003,-56.939003,-52.939003,-49.939003,-50.939003,-52.939003,-49.939003,-48.939003,-49.939003,-53.939003,-56.939003,13.060997,1.060997,-12.939003,-12.939003,-14.939003,-17.939003,-18.939003,-21.939003,-34.939003,-38.939003,-36.939003,-30.939003,-25.939003,-22.939003,-16.939003,-10.939003,-6.939003,-6.939003,-6.939003,1.060997,7.060997,10.060997,-7.939003,-27.939003,-47.939003,-54.939003,-57.939003,-62.939003,-69.939,-78.939,-80.939,-83.939,-89.939,-94.939,-97.939,-97.939,-98.939,-99.939,-100.939,-100.939,-101.939,-98.939,-94.939,-91.939,-87.939,-83.939,-77.939,-71.939,-72.939,-51.939003,-25.939003,-21.939003,-23.939003,-30.939003,-28.939003,-23.939003,-17.939003,-10.939003,-3.939003,1.060997,8.060997,16.060997,19.060997,22.060997,29.060997,35.060997,39.060997,36.060997,33.060997,29.060997,27.060997,26.060997,24.060997,19.060997,14.060997,2.060997,-6.939003,-10.939003,-21.939003,-31.939003,-42.939003,-48.939003,-53.939003,-58.939003,-71.939,-91.939,-44.939003,12.060997,77.061,88.061,80.061,94.061,40.0
60997,-80.939,-97.939,-100.939,-101.939,-100.939,-100.939,-98.939,-93.939,-86.939,-80.939,-76.939,-71.939,-65.939,-57.939003,-54.939003,-47.939003,-35.939003,-27.939003,-19.939003,-12.939003,-3.939003,4.060997,9.060997,15.060997,23.060997,29.060997,34.060997,33.060997,33.060997,32.060997,33.060997,33.060997,35.060997,29.060997,23.060997,22.060997,14.060997,3.060997,-5.939003,-16.939003,-29.939003,-40.939003,-50.939003,-54.939003,-63.939003,-75.939,-87.939,-94.939,-95.939,-94.939,-93.939,-94.939,-95.939,-97.939,-98.939,-99.939,-100.939,-101.939,-73.939,24.060997,68.061,85.061,77.061,49.060997,0.06099701,-45.939003,-70.939,-50.939003,-39.939003,-34.939003,-30.939003,-23.939003,-15.939003,-7.939003,-0.939003,5.060997,13.060997,22.060997,27.060997,31.060997,31.060997,33.060997,34.060997,36.060997,34.060997,32.060997,31.060997,27.060997,23.060997,12.060997,0.06099701,-7.939003,-19.939003,-31.939003,-40.939003,-49.939003,-58.939003,-71.939,-80.939,-85.939,-90.939,-95.939,-96.939,-96.939,-97.939,-97.939,-98.939,-99.939,-99.939,-100.939,-101.939,-102.939,-102.939,-95.939,-87.939,-80.939,-74.939,-68.939,-43.939003,-23.939003,-7.939003,-6.939003,-6.939003,-8.939003,-11.939003,-13.939003,-17.939003,-22.939003,-26.939003,-24.939003,-23.939003,-20.939003,-21.939003,-24.939003,-16.939003,-7.939003,2.060997,0.06099701,-1.939003,1.060997,1.060997,2.060997,4.060997,10.060997,18.060997,28.060997,38.060997,47.060997,50.060997,52.060997,57.060997,62.060997,67.061,76.061,82.061,82.061,85.061,86.061,40.060997,37.060997,75.061,98.061,85.061,-9.939003,-66.939,-102.939,-96.939,-92.939,-90.939,-93.939,-95.939,-92.939,-97.939,-103.939,-103.939,-93.939,-72.939,-29.939003,3.060997,3.060997,8.060997,15.060997,22.060997,29.060997,37.060997,36.060997,36.060997,39.060997,38.060997,34.060997,29.060997,27.060997,28.060997,21.060997,13.060997,5.060997,-3.939003,-14.939003,-21.939003,-31.939003,-44.939003,-54.939003,-61.939003,-63.939003,-72.939,-85.939,-88.939,-94.939,-103.939,-100.939,-96.939,-94.939
,-91.939,-89.939,-91.939,-94.939,-97.939,-100.939,-100.939,-99.939,-94.939,-89.939,-74.939,-48.939003,-13.939003,-4.939003,-3.939003,-14.939003,-18.939003,-19.939003,-7.939003,2.060997,9.060997,14.060997,22.060997,37.060997,20.060997,-10.939003,-42.939003,-62.939003,-71.939,-75.939,-79.939,-83.939,-80.939,-76.939,-79.939,-80.939,-81.939,-80.939,-79.939,-77.939,-78.939,-79.939,-76.939,-74.939,-73.939,-71.939,-67.939,-66.939,-70.939,-77.939,-76.939,-73.939,-71.939,-69.939,-68.939,-72.939,-73.939,-73.939,-74.939,-76.939,-77.939,-78.939,-79.939,-77.939,-77.939,-77.939,-77.939,-76.939,-75.939,-76.939,-78.939,-80.939,-79.939,-77.939,-74.939,-72.939,-70.939,-72.939,-76.939,-79.939,-80.939,-79.939,-78.939,-76.939,-74.939,-73.939,-73.939,-71.939,-70.939,-68.939,-73.939,-76.939,-77.939,-74.939,-71.939,-70.939,-72.939,-74.939,-78.939,-79.939,-76.939,-77.939,-77.939,-78.939,-76.939,-74.939,-75.939,-75.939,-76.939,-74.939,-75.939,-81.939,-78.939,-74.939,-76.939,-77.939,-78.939,-77.939,-76.939,-74.939,-74.939,-76.939,-77.939,-76.939,-74.939,-73.939,-73.939,-74.939,-75.939,-76.939,-77.939,-78.939,-79.939,-80.939,-80.939,-78.939,-75.939,-73.939,-73.939,-74.939,-76.939,-77.939,-77.939,-78.939,-78.939,-78.939,-77.939,-75.939,-72.939,-75.939,-77.939,-78.939,-75.939,-73.939,-78.939,-78.939,-71.939,-74.939,-76.939,-73.939,-74.939,-77.939,-76.939,-76.939,-78.939,-77.939,-77.939,-78.939,-79.939,-79.939,-78.939,-77.939,-74.939,-73.939,-74.939,-75.939,-75.939,-73.939,-77.939,-78.939,-75.939,-76.939,-76.939,-75.939,-75.939,-75.939,-75.939,-75.939,-76.939,-78.939,-79.939,-81.939,-80.939,-78.939,-80.939,-80.939,-79.939,-80.939,-80.939,-82.939,-81.939,-79.939,-81.939,-83.939,-84.939,-83.939,-83.939,-84.939,-87.939,-89.939,-89.939,-88.939,-85.939,-84.939,-82.939,-78.939,-77.939,-76.939,-77.939,-75.939,-69.939,-70.939,-70.939,-66.939,-64.939,-62.939003,-64.939,-65.939,-66.939,-64.939,-62.939003,-60.939003,-61.939003,-62.939003,-63.939003,-64.939,-64.939,-63.939003,-63.939003,-64.939,-67.939,-68.9
39,-71.939,-72.939,-71.939,-70.939,-70.939,-72.939,-73.939,-75.939,-71.939,-71.939,-76.939,-78.939,-78.939,-76.939,-77.939,-80.939,-83.939,-84.939,-81.939,-82.939,-82.939,-80.939,-82.939,-85.939,-82.939,-80.939,-77.939,-80.939,-82.939,-80.939,-79.939,-78.939,-77.939,-78.939,-80.939,-81.939,-81.939,-80.939,-80.939,-80.939,-82.939,-82.939,-81.939,-80.939,-79.939,-80.939,-79.939,-77.939,-79.939,-82.939,-87.939,-84.939,-82.939,-83.939,-85.939,-87.939,-85.939,-82.939,-79.939,-81.939,-84.939,-86.939,-85.939,-82.939,-82.939,-83.939,-85.939,-81.939,-78.939,-77.939,-79.939,-81.939,-82.939,-83.939,-84.939,-83.939,-83.939,-83.939,-82.939,-80.939,-85.939,-86.939,-81.939,-80.939,-80.939,-82.939,-77.939,-71.939,-72.939,-75.939,-81.939,-79.939,-78.939,-81.939,-82.939,-81.939,-80.939,-80.939,-80.939,-79.939,-79.939,-79.939,-78.939,-77.939,-74.939,-74.939,-76.939,-80.939,-83.939,-84.939,-85.939,-86.939,-83.939,-80.939,-78.939,-82.939,-86.939,-86.939,-84.939,-82.939,-83.939,-84.939,-85.939,-83.939,-82.939,-84.939,-84.939,-82.939,-83.939,-83.939,-84.939,-81.939,-79.939,-81.939,-80.939,-79.939,-80.939,-82.939,-83.939,-84.939,-84.939,-85.939,-84.939,-83.939,-84.939,-86.939,-88.939,-88.939,-87.939,-85.939,-83.939,-82.939,-84.939,-86.939,-87.939,-88.939,-88.939,-87.939,-84.939,-81.939,-85.939,-86.939,-85.939,-87.939,-88.939,-85.939,-82.939,-80.939,-83.939,-86.939,-87.939,-84.939,-82.939,-81.939,-82.939,-84.939,-86.939,-88.939,-89.939,-88.939,-87.939,-87.939,-88.939,-89.939,-90.939,-91.939,-90.939,-89.939,-88.939,-90.939,-90.939,-91.939,-90.939,-90.939,-92.939,-93.939,-93.939,-91.939,-93.939,-96.939,-95.939,-93.939,-91.939,-96.939,-100.939,-101.939,-101.939,-100.939,-96.939,-85.939,-68.939,-67.939,-71.939,-82.939,-83.939,-79.939,-71.939,-68.939,-71.939,-68.939,-66.939,-66.939,-70.939,-75.939,-68.939,-58.939003,-48.939003,-47.939003,-47.939003,-47.939003,-53.939003,-62.939003,-65.939,-66.939,-62.939003,-62.939003,-62.939003,-63.939003,-56.939003,-47.939003,-55.939003,-57.939003,-55.939003,-
51.939003,-49.939003,-54.939003,-56.939003,-57.939003,-53.939003,-49.939003,-48.939003,-52.939003,-56.939003,-17.939003,-16.939003,-16.939003,-13.939003,-12.939003,-11.939003,-9.939003,-7.939003,-5.939003,-3.939003,-0.939003,9.060997,17.060997,22.060997,29.060997,36.060997,41.060997,42.060997,45.060997,51.060997,57.060997,59.060997,-0.939003,-52.939003,-76.939,-84.939,-85.939,-88.939,-89.939,-89.939,-84.939,-80.939,-80.939,-81.939,-81.939,-75.939,-71.939,-67.939,-63.939003,-59.939003,-56.939003,-50.939003,-44.939003,-39.939003,-34.939003,-26.939003,-22.939003,-19.939003,-19.939003,-12.939003,-5.939003,-5.939003,-1.939003,6.060997,12.060997,16.060997,16.060997,15.060997,15.060997,12.060997,10.060997,9.060997,3.060997,-1.939003,-3.939003,-6.939003,-10.939003,-18.939003,-26.939003,-32.939003,-36.939003,-41.939003,-45.939003,-49.939003,-51.939003,-57.939003,-61.939003,-64.939,-69.939,-74.939,-78.939,-82.939,-84.939,-84.939,-90.939,-101.939,-55.939003,2.060997,71.061,75.061,55.060997,65.061,28.060997,-54.939003,-67.939,-68.939,-59.939003,-52.939003,-46.939003,-42.939003,-36.939003,-28.939003,-22.939003,-17.939003,-13.939003,-6.939003,0.06099701,0.06099701,3.060997,9.060997,13.060997,15.060997,16.060997,16.060997,14.060997,12.060997,8.060997,4.060997,2.060997,-1.939003,-7.939003,-14.939003,-23.939003,-27.939003,-33.939003,-39.939003,-43.939003,-47.939003,-47.939003,-51.939003,-57.939003,-61.939003,-66.939,-73.939,-78.939,-82.939,-85.939,-88.939,-93.939,-93.939,-91.939,-86.939,-81.939,-78.939,-74.939,-71.939,-68.939,-62.939003,-59.939003,-57.939003,-51.939003,-35.939003,-0.939003,14.060997,19.060997,13.060997,4.060997,-7.939003,-6.939003,-2.939003,7.060997,11.060997,13.060997,14.060997,15.060997,15.060997,13.060997,11.060997,8.060997,4.060997,1.060997,1.060997,-2.939003,-10.939003,-17.939003,-24.939003,-28.939003,-35.939003,-42.939003,-43.939003,-45.939003,-47.939003,-53.939003,-58.939003,-62.939003,-68.939,-74.939,-78.939,-82.939,-86.939,-91.939,-94.939,-94.939,-92.939,-8
8.939,-83.939,-77.939,-71.939,-68.939,-64.939,-57.939003,-52.939003,-47.939003,-42.939003,-39.939003,-38.939003,-30.939003,-21.939003,-16.939003,-11.939003,-7.939003,-0.939003,2.060997,1.060997,0.06099701,-1.939003,-8.939003,-12.939003,-14.939003,-12.939003,-8.939003,-4.939003,-2.939003,-0.939003,5.060997,12.060997,17.060997,16.060997,12.060997,5.060997,17.060997,27.060997,27.060997,45.060997,68.061,59.060997,53.060997,49.060997,61.060997,70.061,68.061,68.061,69.061,64.061,61.060997,60.060997,61.060997,59.060997,51.060997,48.060997,45.060997,14.060997,9.060997,30.060997,42.060997,33.060997,-21.939003,-66.939,-102.939,-95.939,-93.939,-95.939,-97.939,-98.939,-98.939,-100.939,-103.939,-103.939,-84.939,-46.939003,0.06099701,27.060997,8.060997,3.060997,3.060997,3.060997,1.060997,-1.939003,-10.939003,-18.939003,-20.939003,-25.939003,-32.939003,-40.939003,-44.939003,-44.939003,-48.939003,-52.939003,-55.939003,-60.939003,-65.939,-69.939,-74.939,-80.939,-85.939,-87.939,-87.939,-88.939,-91.939,-92.939,-92.939,-93.939,-93.939,-93.939,-92.939,-89.939,-86.939,-89.939,-89.939,-87.939,-89.939,-92.939,-99.939,-82.939,-55.939003,-12.939003,2.060997,-7.939003,-13.939003,-17.939003,-19.939003,-11.939003,-0.939003,4.060997,7.060997,6.060997,7.060997,5.060997,-3.939003,-26.939003,-54.939003,-66.939,-72.939,-74.939,-76.939,-78.939,-79.939,-79.939,-79.939,-77.939,-76.939,-75.939,-72.939,-71.939,-74.939,-77.939,-79.939,-73.939,-71.939,-73.939,-73.939,-70.939,-64.939,-68.939,-75.939,-75.939,-73.939,-70.939,-70.939,-71.939,-74.939,-74.939,-72.939,-71.939,-73.939,-76.939,-77.939,-78.939,-76.939,-77.939,-78.939,-78.939,-76.939,-74.939,-74.939,-75.939,-77.939,-74.939,-71.939,-73.939,-74.939,-75.939,-75.939,-74.939,-73.939,-74.939,-76.939,-77.939,-76.939,-73.939,-72.939,-72.939,-73.939,-71.939,-68.939,-70.939,-71.939,-72.939,-72.939,-71.939,-72.939,-74.939,-74.939,-77.939,-77.939,-74.939,-75.939,-76.939,-77.939,-75.939,-73.939,-74.939,-74.939,-77.939,-75.939,-75.939,-78.939,-77.939,-74.939,-77.9
39,-78.939,-77.939,-75.939,-74.939,-73.939,-74.939,-75.939,-75.939,-74.939,-73.939,-73.939,-74.939,-76.939,-77.939,-78.939,-76.939,-75.939,-74.939,-75.939,-75.939,-76.939,-76.939,-76.939,-74.939,-75.939,-76.939,-76.939,-76.939,-76.939,-77.939,-77.939,-78.939,-77.939,-75.939,-75.939,-75.939,-76.939,-75.939,-74.939,-77.939,-77.939,-73.939,-76.939,-78.939,-74.939,-74.939,-76.939,-75.939,-76.939,-77.939,-78.939,-79.939,-80.939,-81.939,-80.939,-76.939,-75.939,-74.939,-75.939,-76.939,-73.939,-73.939,-72.939,-77.939,-79.939,-77.939,-75.939,-74.939,-74.939,-76.939,-77.939,-77.939,-77.939,-76.939,-77.939,-78.939,-78.939,-78.939,-77.939,-79.939,-79.939,-80.939,-79.939,-78.939,-80.939,-80.939,-80.939,-82.939,-84.939,-86.939,-84.939,-82.939,-82.939,-84.939,-85.939,-87.939,-87.939,-87.939,-85.939,-81.939,-77.939,-79.939,-82.939,-82.939,-79.939,-73.939,-73.939,-72.939,-70.939,-68.939,-65.939,-68.939,-70.939,-72.939,-69.939,-65.939,-62.939003,-63.939003,-66.939,-70.939,-70.939,-68.939,-67.939,-68.939,-70.939,-70.939,-70.939,-73.939,-76.939,-77.939,-76.939,-74.939,-74.939,-75.939,-76.939,-73.939,-73.939,-77.939,-79.939,-79.939,-75.939,-76.939,-79.939,-80.939,-80.939,-79.939,-82.939,-84.939,-83.939,-84.939,-84.939,-82.939,-81.939,-79.939,-82.939,-84.939,-81.939,-80.939,-80.939,-79.939,-79.939,-81.939,-80.939,-79.939,-82.939,-82.939,-82.939,-82.939,-81.939,-79.939,-79.939,-80.939,-82.939,-82.939,-80.939,-80.939,-82.939,-87.939,-85.939,-83.939,-83.939,-84.939,-85.939,-86.939,-85.939,-83.939,-83.939,-83.939,-83.939,-83.939,-83.939,-81.939,-81.939,-82.939,-79.939,-78.939,-78.939,-78.939,-78.939,-80.939,-82.939,-83.939,-84.939,-83.939,-84.939,-83.939,-81.939,-84.939,-84.939,-79.939,-78.939,-77.939,-79.939,-78.939,-75.939,-74.939,-76.939,-80.939,-78.939,-77.939,-81.939,-83.939,-83.939,-82.939,-81.939,-82.939,-80.939,-79.939,-78.939,-79.939,-79.939,-76.939,-75.939,-77.939,-79.939,-80.939,-81.939,-82.939,-83.939,-83.939,-82.939,-79.939,-83.939,-85.939,-84.939,-83.939,-82.939,-84.939,-85.939
,-84.939,-81.939,-80.939,-82.939,-82.939,-81.939,-82.939,-82.939,-82.939,-80.939,-80.939,-82.939,-82.939,-81.939,-81.939,-82.939,-84.939,-83.939,-83.939,-84.939,-84.939,-84.939,-85.939,-86.939,-86.939,-86.939,-86.939,-84.939,-83.939,-83.939,-83.939,-84.939,-86.939,-87.939,-87.939,-84.939,-82.939,-82.939,-86.939,-87.939,-84.939,-87.939,-89.939,-86.939,-83.939,-81.939,-84.939,-87.939,-88.939,-85.939,-83.939,-82.939,-83.939,-85.939,-87.939,-87.939,-86.939,-87.939,-88.939,-85.939,-87.939,-89.939,-88.939,-89.939,-89.939,-90.939,-91.939,-92.939,-92.939,-90.939,-91.939,-91.939,-92.939,-93.939,-94.939,-92.939,-94.939,-97.939,-95.939,-94.939,-93.939,-97.939,-100.939,-102.939,-102.939,-100.939,-99.939,-91.939,-74.939,-69.939,-69.939,-80.939,-81.939,-78.939,-77.939,-75.939,-72.939,-70.939,-67.939,-64.939,-60.939003,-59.939003,-60.939003,-59.939003,-57.939003,-53.939003,-51.939003,-48.939003,-49.939003,-51.939003,-51.939003,-52.939003,-55.939003,-64.939,-68.939,-57.939003,-54.939003,-54.939003,-58.939003,-58.939003,-54.939003,-55.939003,-56.939003,-56.939003,-53.939003,-50.939003,-51.939003,-51.939003,-51.939003,-54.939003,-57.939003,-16.939003,-9.939003,0.06099701,7.060997,6.060997,-0.939003,3.060997,14.060997,37.060997,43.060997,44.060997,54.060997,61.060997,66.061,71.061,74.061,77.061,78.061,81.061,83.061,85.061,85.061,4.060997,-63.939003,-86.939,-92.939,-90.939,-88.939,-84.939,-76.939,-66.939,-58.939003,-54.939003,-51.939003,-48.939003,-38.939003,-32.939003,-26.939003,-20.939003,-14.939003,-9.939003,-3.939003,3.060997,5.060997,10.060997,16.060997,17.060997,16.060997,15.060997,12.060997,6.060997,2.060997,5.060997,16.060997,21.060997,22.060997,17.060997,9.060997,1.060997,-5.939003,-13.939003,-20.939003,-30.939003,-40.939003,-46.939003,-54.939003,-62.939003,-72.939,-80.939,-85.939,-90.939,-94.939,-98.939,-98.939,-97.939,-96.939,-96.939,-96.939,-94.939,-91.939,-89.939,-88.939,-88.939,-83.939,-82.939,-83.939,-49.939003,-6.939003,46.060997,44.060997,21.060997,26.060997,10.060997,
-25.939003,-27.939003,-24.939003,-14.939003,-5.939003,1.060997,4.060997,9.060997,15.060997,18.060997,20.060997,19.060997,23.060997,26.060997,21.060997,18.060997,18.060997,17.060997,14.060997,9.060997,2.060997,-5.939003,-11.939003,-20.939003,-31.939003,-39.939003,-47.939003,-55.939003,-65.939,-75.939,-82.939,-90.939,-98.939,-98.939,-98.939,-97.939,-95.939,-94.939,-93.939,-92.939,-90.939,-88.939,-87.939,-86.939,-84.939,-81.939,-73.939,-64.939,-55.939003,-48.939003,-43.939003,-37.939003,-31.939003,-25.939003,-16.939003,-11.939003,-10.939003,-1.939003,1.060997,-10.939003,-17.939003,-21.939003,-23.939003,-17.939003,-5.939003,15.060997,31.060997,30.060997,27.060997,23.060997,21.060997,17.060997,9.060997,1.060997,-6.939003,-15.939003,-25.939003,-35.939003,-38.939003,-46.939003,-57.939003,-69.939,-79.939,-86.939,-93.939,-100.939,-99.939,-97.939,-96.939,-95.939,-93.939,-92.939,-90.939,-88.939,-87.939,-85.939,-84.939,-81.939,-78.939,-73.939,-65.939,-55.939003,-47.939003,-39.939003,-29.939003,-24.939003,-19.939003,-10.939003,-3.939003,3.060997,9.060997,12.060997,12.060997,18.060997,23.060997,24.060997,24.060997,23.060997,19.060997,11.060997,1.060997,3.060997,6.060997,5.060997,6.060997,7.060997,13.060997,21.060997,30.060997,32.060997,34.060997,41.060997,51.060997,60.060997,50.060997,32.060997,5.060997,28.060997,46.060997,40.060997,67.061,102.061,86.061,70.061,56.060997,66.061,72.061,62.060997,59.060997,59.060997,48.060997,41.060997,37.060997,30.060997,23.060997,13.060997,8.060997,4.060997,-7.939003,-11.939003,-8.939003,-4.939003,-7.939003,-26.939003,-63.939003,-102.939,-96.939,-95.939,-98.939,-98.939,-99.939,-101.939,-102.939,-101.939,-102.939,-83.939,-41.939003,-3.939003,14.060997,-14.939003,-24.939003,-29.939003,-33.939003,-40.939003,-49.939003,-61.939003,-71.939,-77.939,-83.939,-90.939,-97.939,-100.939,-99.939,-99.939,-99.939,-99.939,-98.939,-98.939,-99.939,-99.939,-98.939,-98.939,-97.939,-97.939,-92.939,-87.939,-87.939,-86.939,-85.939,-86.939,-87.939,-88.939,-85.939,-83.939
,-86.939,-85.939,-78.939,-79.939,-84.939,-95.939,-73.939,-36.939003,15.060997,27.060997,2.060997,-2.939003,-3.939003,1.060997,2.060997,1.060997,-8.939003,-14.939003,-20.939003,-19.939003,-25.939003,-49.939003,-68.939,-85.939,-80.939,-77.939,-76.939,-75.939,-74.939,-74.939,-76.939,-79.939,-75.939,-73.939,-71.939,-67.939,-66.939,-72.939,-76.939,-77.939,-69.939,-68.939,-73.939,-74.939,-73.939,-66.939,-68.939,-73.939,-73.939,-72.939,-69.939,-70.939,-72.939,-73.939,-73.939,-71.939,-69.939,-70.939,-75.939,-76.939,-77.939,-75.939,-76.939,-77.939,-77.939,-76.939,-74.939,-72.939,-72.939,-73.939,-71.939,-69.939,-72.939,-75.939,-78.939,-75.939,-72.939,-69.939,-71.939,-74.939,-75.939,-75.939,-72.939,-71.939,-71.939,-74.939,-72.939,-68.939,-68.939,-68.939,-69.939,-70.939,-72.939,-74.939,-75.939,-73.939,-75.939,-74.939,-72.939,-74.939,-75.939,-77.939,-76.939,-73.939,-74.939,-74.939,-76.939,-74.939,-74.939,-75.939,-74.939,-72.939,-76.939,-78.939,-77.939,-74.939,-72.939,-73.939,-74.939,-75.939,-73.939,-72.939,-72.939,-73.939,-74.939,-76.939,-77.939,-78.939,-76.939,-74.939,-72.939,-72.939,-72.939,-75.939,-77.939,-78.939,-75.939,-75.939,-76.939,-75.939,-75.939,-75.939,-76.939,-77.939,-77.939,-77.939,-77.939,-74.939,-73.939,-75.939,-74.939,-73.939,-75.939,-75.939,-74.939,-77.939,-78.939,-74.939,-74.939,-75.939,-74.939,-75.939,-77.939,-77.939,-78.939,-79.939,-80.939,-79.939,-75.939,-74.939,-75.939,-77.939,-77.939,-73.939,-72.939,-73.939,-76.939,-78.939,-77.939,-75.939,-72.939,-74.939,-76.939,-78.939,-78.939,-77.939,-76.939,-77.939,-78.939,-76.939,-76.939,-76.939,-78.939,-79.939,-80.939,-78.939,-77.939,-78.939,-79.939,-80.939,-82.939,-84.939,-86.939,-83.939,-81.939,-81.939,-82.939,-82.939,-86.939,-88.939,-89.939,-86.939,-82.939,-78.939,-80.939,-85.939,-84.939,-81.939,-77.939,-76.939,-75.939,-75.939,-73.939,-71.939,-73.939,-75.939,-75.939,-72.939,-68.939,-64.939,-66.939,-71.939,-75.939,-74.939,-71.939,-72.939,-73.939,-76.939,-74.939,-71.939,-74.939,-77.939,-80.939,-79.939,-78.939,-77.939
,-77.939,-78.939,-74.939,-75.939,-78.939,-80.939,-80.939,-76.939,-77.939,-80.939,-79.939,-78.939,-77.939,-81.939,-84.939,-83.939,-84.939,-84.939,-82.939,-81.939,-81.939,-83.939,-85.939,-82.939,-81.939,-81.939,-80.939,-80.939,-81.939,-79.939,-78.939,-83.939,-84.939,-83.939,-81.939,-79.939,-78.939,-79.939,-80.939,-83.939,-83.939,-82.939,-81.939,-82.939,-86.939,-84.939,-83.939,-83.939,-83.939,-82.939,-85.939,-87.939,-86.939,-84.939,-82.939,-81.939,-82.939,-84.939,-81.939,-81.939,-80.939,-79.939,-79.939,-79.939,-79.939,-77.939,-80.939,-82.939,-82.939,-84.939,-85.939,-86.939,-84.939,-81.939,-83.939,-82.939,-79.939,-77.939,-76.939,-77.939,-78.939,-80.939,-77.939,-76.939,-79.939,-76.939,-75.939,-80.939,-83.939,-85.939,-82.939,-81.939,-81.939,-80.939,-78.939,-78.939,-79.939,-80.939,-77.939,-76.939,-77.939,-77.939,-77.939,-78.939,-78.939,-80.939,-82.939,-82.939,-81.939,-82.939,-84.939,-82.939,-82.939,-82.939,-84.939,-84.939,-83.939,-81.939,-80.939,-80.939,-81.939,-79.939,-81.939,-81.939,-79.939,-80.939,-81.939,-82.939,-82.939,-82.939,-81.939,-82.939,-85.939,-83.939,-82.939,-82.939,-83.939,-85.939,-86.939,-86.939,-85.939,-85.939,-85.939,-82.939,-83.939,-85.939,-83.939,-83.939,-85.939,-86.939,-86.939,-82.939,-81.939,-82.939,-87.939,-88.939,-84.939,-86.939,-88.939,-87.939,-84.939,-82.939,-85.939,-86.939,-87.939,-85.939,-83.939,-82.939,-83.939,-84.939,-87.939,-86.939,-83.939,-86.939,-88.939,-85.939,-87.939,-90.939,-87.939,-87.939,-88.939,-91.939,-92.939,-93.939,-92.939,-89.939,-91.939,-92.939,-91.939,-93.939,-94.939,-92.939,-93.939,-96.939,-95.939,-94.939,-94.939,-96.939,-99.939,-101.939,-102.939,-101.939,-102.939,-96.939,-80.939,-71.939,-68.939,-79.939,-79.939,-76.939,-81.939,-80.939,-74.939,-70.939,-67.939,-62.939003,-54.939003,-47.939003,-52.939003,-56.939003,-61.939003,-59.939003,-56.939003,-53.939003,-49.939003,-46.939003,-42.939003,-42.939003,-49.939003,-62.939003,-67.939,-50.939003,-52.939003,-60.939003,-62.939003,-60.939003,-54.939003,-58.939003,-61.939003,-58.939003,-52
.939003,-46.939003,-51.939003,-54.939003,-55.939003,-57.939003,-58.939003,52.060997,52.060997,51.060997,62.060997,51.060997,21.060997,22.060997,39.060997,86.061,95.061,87.061,86.061,86.061,87.061,86.061,83.061,77.061,73.061,70.061,64.061,59.060997,55.060997,3.060997,-41.939003,-56.939003,-55.939003,-47.939003,-40.939003,-31.939003,-21.939003,-15.939003,-10.939003,-5.939003,-0.939003,5.060997,12.060997,16.060997,14.060997,17.060997,20.060997,25.060997,28.060997,28.060997,23.060997,19.060997,17.060997,12.060997,7.060997,0.06099701,-2.939003,-4.939003,-8.939003,-18.939003,-35.939003,-43.939003,-52.939003,-57.939003,-68.939,-78.939,-80.939,-82.939,-84.939,-86.939,-88.939,-89.939,-91.939,-94.939,-95.939,-95.939,-93.939,-91.939,-90.939,-87.939,-83.939,-78.939,-73.939,-71.939,-73.939,-63.939003,-53.939003,-44.939003,-40.939003,-36.939003,-32.939003,-24.939003,-14.939003,-11.939003,-9.939003,-11.939003,-14.939003,-16.939003,-15.939003,-10.939003,-2.939003,17.060997,30.060997,24.060997,23.060997,22.060997,17.060997,14.060997,14.060997,7.060997,-1.939003,-12.939003,-21.939003,-28.939003,-41.939003,-49.939003,-54.939003,-60.939003,-67.939,-75.939,-79.939,-80.939,-81.939,-83.939,-86.939,-88.939,-90.939,-92.939,-94.939,-97.939,-97.939,-95.939,-93.939,-86.939,-79.939,-75.939,-70.939,-64.939,-59.939003,-53.939003,-46.939003,-40.939003,-34.939003,-30.939003,-26.939003,-20.939003,-11.939003,-4.939003,3.060997,9.060997,15.060997,21.060997,27.060997,32.060997,32.060997,32.060997,30.060997,30.060997,26.060997,14.060997,10.060997,10.060997,14.060997,16.060997,16.060997,-6.939003,-27.939003,-34.939003,-43.939003,-51.939003,-56.939003,-64.939,-74.939,-78.939,-81.939,-83.939,-85.939,-87.939,-86.939,-88.939,-92.939,-95.939,-98.939,-97.939,-97.939,-95.939,-85.939,-77.939,-73.939,-67.939,-62.939003,-54.939003,-46.939003,-39.939003,-32.939003,-26.939003,-20.939003,-13.939003,-7.939003,-2.939003,6.060997,15.060997,17.060997,20.060997,25.060997,28.060997,30.060997,30.060997,30.060997,29.060997,2
4.060997,19.060997,16.060997,11.060997,4.060997,-3.939003,-15.939003,-26.939003,-22.939003,-19.939003,-20.939003,-4.939003,15.060997,38.060997,55.060997,68.061,69.061,72.061,77.061,80.061,82.061,83.061,84.061,85.061,71.061,44.060997,2.060997,23.060997,37.060997,20.060997,31.060997,51.060997,35.060997,20.060997,6.060997,7.060997,7.060997,0.06099701,-2.939003,-3.939003,-9.939003,-11.939003,-7.939003,-15.939003,-21.939003,-19.939003,-17.939003,-15.939003,-9.939003,-7.939003,-9.939003,-4.939003,-0.939003,-1.939003,-46.939003,-103.939,-99.939,-96.939,-95.939,-97.939,-98.939,-99.939,-97.939,-96.939,-101.939,-98.939,-89.939,-80.939,-76.939,-82.939,-85.939,-86.939,-86.939,-87.939,-88.939,-89.939,-91.939,-95.939,-98.939,-99.939,-97.939,-91.939,-84.939,-85.939,-86.939,-87.939,-85.939,-82.939,-85.939,-84.939,-79.939,-79.939,-80.939,-82.939,-81.939,-79.939,-79.939,-80.939,-83.939,-81.939,-79.939,-80.939,-80.939,-82.939,-85.939,-84.939,-78.939,-77.939,-79.939,-82.939,-72.939,-57.939003,-45.939003,-22.939003,10.060997,42.060997,65.061,66.061,23.060997,-36.939003,-66.939,-83.939,-83.939,-74.939,-68.939,-75.939,-73.939,-66.939,-69.939,-72.939,-79.939,-73.939,-68.939,-72.939,-74.939,-74.939,-73.939,-72.939,-73.939,-70.939,-68.939,-70.939,-72.939,-72.939,-68.939,-68.939,-73.939,-73.939,-71.939,-73.939,-73.939,-72.939,-70.939,-68.939,-67.939,-68.939,-69.939,-68.939,-69.939,-69.939,-70.939,-72.939,-73.939,-73.939,-72.939,-73.939,-73.939,-74.939,-76.939,-76.939,-75.939,-73.939,-70.939,-70.939,-72.939,-74.939,-71.939,-70.939,-72.939,-70.939,-70.939,-71.939,-72.939,-75.939,-74.939,-71.939,-68.939,-69.939,-70.939,-72.939,-71.939,-69.939,-67.939,-66.939,-67.939,-71.939,-73.939,-74.939,-73.939,-72.939,-72.939,-72.939,-70.939,-73.939,-77.939,-79.939,-78.939,-77.939,-75.939,-74.939,-74.939,-72.939,-72.939,-73.939,-73.939,-71.939,-75.939,-76.939,-77.939,-73.939,-71.939,-74.939,-75.939,-76.939,-72.939,-71.939,-71.939,-72.939,-72.939,-73.939,-74.939,-75.939,-76.939,-76.939,-77.939,-75.939,-74.939
,-73.939,-75.939,-79.939,-76.939,-76.939,-78.939,-77.939,-77.939,-76.939,-76.939,-77.939,-75.939,-74.939,-73.939,-72.939,-72.939,-73.939,-73.939,-71.939,-71.939,-72.939,-73.939,-75.939,-75.939,-74.939,-72.939,-70.939,-71.939,-73.939,-75.939,-75.939,-73.939,-73.939,-74.939,-76.939,-77.939,-77.939,-78.939,-76.939,-75.939,-75.939,-75.939,-74.939,-74.939,-74.939,-76.939,-74.939,-74.939,-75.939,-76.939,-76.939,-73.939,-72.939,-74.939,-76.939,-77.939,-76.939,-75.939,-74.939,-79.939,-81.939,-81.939,-79.939,-78.939,-80.939,-79.939,-78.939,-79.939,-81.939,-84.939,-82.939,-81.939,-83.939,-83.939,-84.939,-88.939,-91.939,-93.939,-91.939,-87.939,-83.939,-82.939,-82.939,-80.939,-79.939,-78.939,-78.939,-80.939,-80.939,-81.939,-80.939,-79.939,-76.939,-72.939,-70.939,-69.939,-69.939,-71.939,-76.939,-75.939,-73.939,-68.939,-73.939,-78.939,-77.939,-75.939,-74.939,-76.939,-76.939,-75.939,-78.939,-80.939,-80.939,-80.939,-80.939,-76.939,-75.939,-78.939,-79.939,-80.939,-77.939,-80.939,-83.939,-82.939,-80.939,-79.939,-79.939,-79.939,-77.939,-80.939,-84.939,-83.939,-83.939,-83.939,-84.939,-83.939,-82.939,-82.939,-83.939,-82.939,-80.939,-79.939,-78.939,-79.939,-82.939,-82.939,-81.939,-78.939,-77.939,-76.939,-77.939,-78.939,-80.939,-80.939,-80.939,-80.939,-81.939,-82.939,-83.939,-82.939,-84.939,-82.939,-81.939,-84.939,-85.939,-88.939,-84.939,-80.939,-79.939,-81.939,-86.939,-83.939,-82.939,-84.939,-83.939,-82.939,-81.939,-80.939,-80.939,-81.939,-81.939,-77.939,-83.939,-88.939,-90.939,-86.939,-81.939,-81.939,-83.939,-84.939,-82.939,-81.939,-78.939,-79.939,-81.939,-78.939,-76.939,-74.939,-73.939,-73.939,-77.939,-81.939,-83.939,-80.939,-77.939,-76.939,-75.939,-76.939,-76.939,-79.939,-81.939,-76.939,-74.939,-75.939,-76.939,-76.939,-76.939,-77.939,-78.939,-79.939,-81.939,-83.939,-81.939,-79.939,-77.939,-79.939,-80.939,-82.939,-83.939,-83.939,-83.939,-83.939,-82.939,-80.939,-78.939,-79.939,-79.939,-77.939,-78.939,-80.939,-80.939,-81.939,-81.939,-79.939,-81.939,-87.939,-83.939,-80.939,-80.939,-82.939
,-84.939,-86.939,-86.939,-85.939,-84.939,-83.939,-80.939,-82.939,-87.939,-83.939,-81.939,-82.939,-83.939,-84.939,-82.939,-82.939,-83.939,-86.939,-86.939,-84.939,-84.939,-84.939,-85.939,-84.939,-83.939,-84.939,-85.939,-85.939,-83.939,-82.939,-82.939,-82.939,-82.939,-85.939,-86.939,-83.939,-87.939,-90.939,-88.939,-90.939,-91.939,-87.939,-85.939,-87.939,-89.939,-91.939,-88.939,-87.939,-86.939,-89.939,-90.939,-90.939,-91.939,-92.939,-91.939,-91.939,-92.939,-94.939,-95.939,-92.939,-94.939,-95.939,-100.939,-102.939,-102.939,-103.939,-97.939,-85.939,-74.939,-68.939,-77.939,-77.939,-74.939,-79.939,-79.939,-75.939,-69.939,-64.939,-64.939,-59.939003,-51.939003,-43.939003,-43.939003,-48.939003,-56.939003,-61.939003,-61.939003,-60.939003,-57.939003,-50.939003,-47.939003,-48.939003,-49.939003,-48.939003,-45.939003,-53.939003,-65.939,-70.939,-67.939,-57.939003,-59.939003,-60.939003,-60.939003,-57.939003,-54.939003,-57.939003,-59.939003,-58.939003,-61.939003,-63.939003,84.061,77.061,70.061,77.061,64.061,31.060997,20.060997,27.060997,70.061,75.061,64.061,59.060997,56.060997,53.060997,49.060997,43.060997,33.060997,28.060997,26.060997,18.060997,13.060997,11.060997,-4.939003,-15.939003,-12.939003,-9.939003,-6.939003,-1.939003,2.060997,7.060997,8.060997,8.060997,8.060997,7.060997,8.060997,7.060997,3.060997,-0.939003,-3.939003,-6.939003,-10.939003,-12.939003,-14.939003,-22.939003,-29.939003,-34.939003,-38.939003,-42.939003,-47.939003,-43.939003,-35.939003,-19.939003,-30.939003,-68.939,-79.939,-86.939,-87.939,-92.939,-97.939,-94.939,-91.939,-87.939,-84.939,-81.939,-78.939,-74.939,-71.939,-67.939,-63.939003,-57.939003,-50.939003,-44.939003,-39.939003,-35.939003,-30.939003,-24.939003,-20.939003,-19.939003,-12.939003,-6.939003,1.060997,2.060997,2.060997,5.060997,8.060997,12.060997,7.060997,2.060997,-7.939003,-7.939003,-3.939003,3.060997,4.060997,-1.939003,-7.939003,-12.939003,-18.939003,-24.939003,-29.939003,-33.939003,-35.939003,-37.939003,-43.939003,-49.939003,-57.939003,-64.939,-70.939,-
79.939,-84.939,-88.939,-91.939,-94.939,-94.939,-91.939,-86.939,-85.939,-83.939,-78.939,-74.939,-70.939,-70.939,-65.939,-59.939003,-55.939003,-51.939003,-45.939003,-36.939003,-29.939003,-23.939003,-18.939003,-13.939003,-9.939003,-5.939003,-0.939003,5.060997,9.060997,13.060997,14.060997,15.060997,15.060997,13.060997,12.060997,11.060997,9.060997,7.060997,5.060997,3.060997,-2.939003,-6.939003,-11.939003,-19.939003,-18.939003,2.060997,34.060997,68.061,67.061,64.061,60.060997,-12.939003,-70.939,-75.939,-81.939,-87.939,-90.939,-93.939,-95.939,-93.939,-90.939,-86.939,-80.939,-73.939,-69.939,-66.939,-64.939,-61.939003,-57.939003,-52.939003,-47.939003,-41.939003,-31.939003,-24.939003,-21.939003,-15.939003,-9.939003,-4.939003,1.060997,5.060997,8.060997,11.060997,14.060997,14.060997,13.060997,13.060997,13.060997,12.060997,6.060997,3.060997,2.060997,-0.939003,-5.939003,-13.939003,-19.939003,-24.939003,-29.939003,-33.939003,-35.939003,-39.939003,-45.939003,-51.939003,-60.939003,-68.939,-58.939003,-44.939003,-27.939003,-6.939003,14.060997,38.060997,56.060997,70.061,61.060997,56.060997,55.060997,55.060997,54.060997,51.060997,48.060997,44.060997,36.060997,21.060997,-0.939003,8.060997,14.060997,6.060997,6.060997,10.060997,1.060997,-4.939003,-9.939003,-9.939003,-9.939003,-12.939003,-13.939003,-12.939003,-10.939003,-7.939003,-4.939003,-6.939003,-6.939003,1.060997,6.060997,9.060997,4.060997,8.060997,21.060997,26.060997,34.060997,53.060997,0.06099701,-76.939,-92.939,-98.939,-93.939,-93.939,-94.939,-98.939,-99.939,-97.939,-96.939,-92.939,-85.939,-84.939,-84.939,-84.939,-82.939,-79.939,-77.939,-77.939,-79.939,-82.939,-85.939,-87.939,-89.939,-88.939,-87.939,-83.939,-78.939,-77.939,-77.939,-80.939,-80.939,-80.939,-81.939,-80.939,-76.939,-77.939,-78.939,-78.939,-78.939,-76.939,-77.939,-79.939,-81.939,-80.939,-79.939,-80.939,-81.939,-82.939,-83.939,-82.939,-79.939,-78.939,-77.939,-74.939,-73.939,-73.939,-78.939,-62.939003,-26.939003,16.060997,50.060997,62.060997,24.060997,-33.939003,-66.939,-8
5.939,-88.939,-81.939,-75.939,-76.939,-72.939,-66.939,-66.939,-68.939,-74.939,-71.939,-68.939,-69.939,-70.939,-70.939,-71.939,-71.939,-72.939,-70.939,-69.939,-70.939,-70.939,-68.939,-67.939,-68.939,-69.939,-69.939,-70.939,-72.939,-71.939,-67.939,-68.939,-68.939,-68.939,-68.939,-69.939,-68.939,-69.939,-71.939,-70.939,-72.939,-74.939,-73.939,-72.939,-73.939,-72.939,-72.939,-73.939,-75.939,-77.939,-74.939,-71.939,-72.939,-73.939,-74.939,-71.939,-70.939,-73.939,-70.939,-69.939,-69.939,-71.939,-75.939,-73.939,-70.939,-67.939,-69.939,-71.939,-71.939,-69.939,-68.939,-67.939,-67.939,-68.939,-71.939,-73.939,-72.939,-73.939,-74.939,-72.939,-71.939,-70.939,-72.939,-74.939,-75.939,-75.939,-75.939,-74.939,-73.939,-73.939,-72.939,-73.939,-74.939,-74.939,-72.939,-73.939,-74.939,-74.939,-73.939,-74.939,-76.939,-75.939,-75.939,-71.939,-71.939,-74.939,-74.939,-74.939,-73.939,-74.939,-74.939,-76.939,-76.939,-77.939,-76.939,-76.939,-76.939,-77.939,-78.939,-75.939,-75.939,-76.939,-75.939,-74.939,-74.939,-75.939,-77.939,-74.939,-74.939,-73.939,-73.939,-73.939,-75.939,-73.939,-70.939,-73.939,-74.939,-73.939,-73.939,-73.939,-74.939,-74.939,-72.939,-72.939,-72.939,-75.939,-74.939,-72.939,-72.939,-73.939,-75.939,-76.939,-77.939,-79.939,-77.939,-77.939,-76.939,-76.939,-77.939,-74.939,-74.939,-76.939,-75.939,-75.939,-76.939,-76.939,-75.939,-74.939,-72.939,-72.939,-75.939,-76.939,-76.939,-75.939,-73.939,-78.939,-80.939,-79.939,-78.939,-78.939,-80.939,-79.939,-78.939,-78.939,-80.939,-83.939,-81.939,-82.939,-85.939,-85.939,-84.939,-88.939,-90.939,-91.939,-89.939,-88.939,-84.939,-82.939,-81.939,-80.939,-79.939,-79.939,-80.939,-81.939,-79.939,-79.939,-80.939,-80.939,-78.939,-74.939,-75.939,-75.939,-73.939,-75.939,-78.939,-78.939,-75.939,-70.939,-76.939,-80.939,-79.939,-77.939,-76.939,-77.939,-76.939,-75.939,-78.939,-81.939,-79.939,-79.939,-79.939,-77.939,-77.939,-78.939,-79.939,-79.939,-78.939,-80.939,-83.939,-81.939,-80.939,-80.939,-78.939,-76.939,-75.939,-78.939,-82.939,-82.939,-82.939,-83.939,-8
2.939,-81.939,-80.939,-82.939,-83.939,-83.939,-81.939,-79.939,-78.939,-77.939,-80.939,-79.939,-77.939,-77.939,-79.939,-81.939,-80.939,-78.939,-79.939,-79.939,-79.939,-80.939,-81.939,-81.939,-82.939,-82.939,-84.939,-83.939,-83.939,-83.939,-84.939,-86.939,-84.939,-82.939,-81.939,-81.939,-85.939,-82.939,-82.939,-86.939,-84.939,-81.939,-80.939,-81.939,-83.939,-82.939,-81.939,-79.939,-83.939,-86.939,-87.939,-86.939,-84.939,-84.939,-85.939,-86.939,-82.939,-80.939,-79.939,-79.939,-78.939,-78.939,-76.939,-72.939,-72.939,-72.939,-74.939,-79.939,-83.939,-81.939,-78.939,-77.939,-75.939,-75.939,-78.939,-80.939,-81.939,-74.939,-71.939,-71.939,-72.939,-74.939,-76.939,-76.939,-77.939,-78.939,-79.939,-80.939,-78.939,-78.939,-79.939,-80.939,-79.939,-77.939,-78.939,-82.939,-83.939,-84.939,-83.939,-80.939,-77.939,-79.939,-79.939,-76.939,-78.939,-79.939,-76.939,-76.939,-77.939,-77.939,-79.939,-83.939,-81.939,-80.939,-81.939,-83.939,-85.939,-85.939,-85.939,-84.939,-84.939,-83.939,-80.939,-82.939,-84.939,-82.939,-81.939,-81.939,-82.939,-83.939,-84.939,-83.939,-83.939,-82.939,-82.939,-84.939,-83.939,-83.939,-86.939,-85.939,-84.939,-84.939,-84.939,-83.939,-83.939,-82.939,-83.939,-82.939,-81.939,-85.939,-86.939,-85.939,-87.939,-90.939,-89.939,-91.939,-92.939,-88.939,-87.939,-87.939,-89.939,-89.939,-87.939,-87.939,-86.939,-90.939,-90.939,-89.939,-91.939,-93.939,-92.939,-90.939,-89.939,-93.939,-95.939,-92.939,-94.939,-96.939,-98.939,-99.939,-99.939,-96.939,-92.939,-84.939,-75.939,-67.939,-70.939,-71.939,-73.939,-74.939,-75.939,-73.939,-69.939,-65.939,-59.939003,-56.939003,-54.939003,-41.939003,-36.939003,-38.939003,-48.939003,-55.939003,-58.939003,-60.939003,-61.939003,-58.939003,-55.939003,-52.939003,-49.939003,-45.939003,-42.939003,-45.939003,-50.939003,-57.939003,-60.939003,-61.939003,-63.939003,-63.939003,-61.939003,-59.939003,-57.939003,-59.939003,-60.939003,-61.939003,-65.939,-69.939,91.061,81.061,69.061,69.061,57.060997,32.060997,10.060997,2.060997,27.060997,26.060997,14.060997,9.06099
7,4.060997,0.06099701,-4.939003,-11.939003,-22.939003,-25.939003,-23.939003,-30.939003,-31.939003,-30.939003,-12.939003,7.060997,28.060997,29.060997,24.060997,21.060997,19.060997,17.060997,12.060997,6.060997,0.06099701,-6.939003,-14.939003,-25.939003,-35.939003,-41.939003,-50.939003,-59.939003,-72.939,-77.939,-80.939,-87.939,-94.939,-100.939,-101.939,-101.939,-101.939,-88.939,-69.939,-30.939003,-36.939003,-88.939,-95.939,-95.939,-90.939,-87.939,-84.939,-78.939,-70.939,-61.939003,-55.939003,-48.939003,-42.939003,-33.939003,-25.939003,-18.939003,-12.939003,-5.939003,4.060997,13.060997,17.060997,19.060997,22.060997,28.060997,33.060997,37.060997,35.060997,33.060997,36.060997,32.060997,26.060997,26.060997,22.060997,13.060997,11.060997,13.060997,22.060997,27.060997,32.060997,48.060997,35.060997,-7.939003,-55.939003,-89.939,-88.939,-95.939,-102.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-99.939,-98.939,-98.939,-97.939,-96.939,-94.939,-89.939,-81.939,-69.939,-59.939003,-57.939003,-52.939003,-42.939003,-31.939003,-23.939003,-22.939003,-12.939003,0.06099701,4.060997,10.060997,14.060997,21.060997,27.060997,32.060997,35.060997,36.060997,36.060997,37.060997,36.060997,37.060997,37.060997,39.060997,34.060997,28.060997,18.060997,7.060997,-3.939003,-14.939003,-25.939003,-37.939003,-48.939003,-58.939003,-68.939,-76.939,-82.939,-94.939,-85.939,-24.939003,51.060997,124.061,115.061,107.061,102.061,-11.939003,-98.939,-97.939,-97.939,-97.939,-96.939,-92.939,-84.939,-75.939,-67.939,-58.939003,-45.939003,-29.939003,-25.939003,-18.939003,-11.939003,-4.939003,4.060997,11.060997,18.060997,24.060997,30.060997,32.060997,32.060997,36.060997,39.060997,38.060997,38.060997,36.060997,32.060997,29.060997,25.060997,16.060997,7.060997,-0.939003,-12.939003,-24.939003,-36.939003,-46.939003,-50.939003,-61.939003,-72.939,-86.939,-96.939,-102.939,-101.939,-101.939,-100.939,-100.939,-99.939,-99.939,-98.939,-97.939,-84.939,-60.939003,-25.939003,-5.939003,10.060997,23.060997,35.060997,43.060997,25
.060997,13.060997,6.060997,3.060997,0.06099701,-5.939003,-11.939003,-16.939003,-17.939003,-12.939003,-3.939003,-7.939003,-9.939003,-2.939003,-8.939003,-17.939003,-17.939003,-14.939003,-7.939003,-3.939003,-1.939003,-0.939003,1.060997,4.060997,16.060997,21.060997,21.060997,27.060997,34.060997,43.060997,50.060997,52.060997,25.060997,29.060997,63.060997,65.061,75.061,109.061,51.060997,-39.939003,-81.939,-99.939,-92.939,-89.939,-90.939,-98.939,-101.939,-101.939,-91.939,-78.939,-62.939003,-59.939003,-59.939003,-60.939003,-56.939003,-51.939003,-48.939003,-48.939003,-53.939003,-62.939003,-71.939,-72.939,-73.939,-72.939,-75.939,-76.939,-77.939,-73.939,-71.939,-75.939,-79.939,-83.939,-81.939,-79.939,-78.939,-81.939,-82.939,-79.939,-77.939,-75.939,-78.939,-80.939,-79.939,-81.939,-83.939,-84.939,-83.939,-82.939,-80.939,-79.939,-80.939,-80.939,-78.939,-70.939,-75.939,-85.939,-96.939,-94.939,-78.939,-38.939003,-0.939003,25.060997,12.060997,-15.939003,-44.939003,-62.939003,-70.939,-70.939,-69.939,-68.939,-70.939,-71.939,-67.939,-65.939,-66.939,-69.939,-71.939,-67.939,-66.939,-66.939,-69.939,-70.939,-69.939,-69.939,-69.939,-70.939,-68.939,-66.939,-68.939,-68.939,-65.939,-66.939,-68.939,-69.939,-67.939,-63.939003,-67.939,-69.939,-70.939,-70.939,-70.939,-70.939,-71.939,-73.939,-71.939,-72.939,-75.939,-74.939,-74.939,-75.939,-73.939,-70.939,-71.939,-74.939,-78.939,-75.939,-73.939,-75.939,-74.939,-71.939,-70.939,-72.939,-76.939,-72.939,-68.939,-67.939,-69.939,-74.939,-72.939,-70.939,-68.939,-71.939,-73.939,-69.939,-67.939,-66.939,-67.939,-68.939,-69.939,-71.939,-71.939,-70.939,-73.939,-78.939,-74.939,-72.939,-71.939,-70.939,-70.939,-69.939,-70.939,-72.939,-72.939,-72.939,-72.939,-74.939,-75.939,-77.939,-76.939,-75.939,-72.939,-71.939,-70.939,-73.939,-77.939,-78.939,-76.939,-73.939,-72.939,-73.939,-78.939,-77.939,-77.939,-75.939,-75.939,-74.939,-75.939,-75.939,-76.939,-77.939,-78.939,-79.939,-79.939,-77.939,-74.939,-73.939,-74.939,-72.939,-71.939,-71.939,-73.939,-76.939,-75.939,-75.939,
-74.939,-75.939,-76.939,-77.939,-74.939,-71.939,-77.939,-77.939,-73.939,-71.939,-72.939,-74.939,-76.939,-77.939,-74.939,-73.939,-75.939,-74.939,-73.939,-73.939,-74.939,-75.939,-74.939,-75.939,-78.939,-79.939,-79.939,-76.939,-77.939,-79.939,-76.939,-75.939,-78.939,-77.939,-77.939,-75.939,-75.939,-75.939,-76.939,-74.939,-70.939,-73.939,-75.939,-76.939,-74.939,-72.939,-77.939,-79.939,-76.939,-77.939,-78.939,-79.939,-80.939,-80.939,-79.939,-79.939,-82.939,-82.939,-83.939,-87.939,-87.939,-85.939,-86.939,-87.939,-87.939,-86.939,-86.939,-84.939,-82.939,-80.939,-81.939,-81.939,-80.939,-81.939,-80.939,-75.939,-75.939,-76.939,-79.939,-79.939,-78.939,-82.939,-83.939,-78.939,-78.939,-79.939,-81.939,-80.939,-75.939,-79.939,-82.939,-80.939,-79.939,-78.939,-78.939,-78.939,-77.939,-80.939,-82.939,-77.939,-76.939,-76.939,-79.939,-80.939,-79.939,-78.939,-78.939,-78.939,-79.939,-81.939,-79.939,-79.939,-81.939,-77.939,-74.939,-76.939,-77.939,-79.939,-80.939,-81.939,-82.939,-80.939,-78.939,-78.939,-81.939,-84.939,-84.939,-82.939,-79.939,-77.939,-75.939,-77.939,-76.939,-74.939,-78.939,-82.939,-86.939,-84.939,-80.939,-79.939,-79.939,-80.939,-81.939,-81.939,-82.939,-81.939,-81.939,-83.939,-85.939,-88.939,-84.939,-82.939,-83.939,-85.939,-85.939,-83.939,-82.939,-83.939,-80.939,-81.939,-88.939,-84.939,-79.939,-78.939,-81.939,-86.939,-82.939,-81.939,-84.939,-83.939,-83.939,-82.939,-84.939,-87.939,-89.939,-88.939,-85.939,-80.939,-78.939,-80.939,-78.939,-74.939,-77.939,-76.939,-71.939,-71.939,-71.939,-72.939,-77.939,-83.939,-83.939,-81.939,-79.939,-77.939,-76.939,-80.939,-82.939,-82.939,-72.939,-67.939,-66.939,-69.939,-72.939,-75.939,-76.939,-77.939,-78.939,-78.939,-77.939,-76.939,-77.939,-83.939,-82.939,-78.939,-73.939,-73.939,-80.939,-82.939,-83.939,-85.939,-81.939,-77.939,-80.939,-80.939,-77.939,-78.939,-77.939,-73.939,-71.939,-73.939,-75.939,-76.939,-78.939,-79.939,-81.939,-83.939,-85.939,-86.939,-84.939,-83.939,-82.939,-83.939,-84.939,-82.939,-81.939,-80.939,-81.939,-82.939,-82.939,-82.939,
-83.939,-86.939,-85.939,-82.939,-78.939,-78.939,-83.939,-82.939,-83.939,-87.939,-87.939,-86.939,-84.939,-83.939,-82.939,-83.939,-83.939,-84.939,-83.939,-82.939,-85.939,-87.939,-87.939,-88.939,-89.939,-90.939,-91.939,-91.939,-90.939,-89.939,-88.939,-88.939,-88.939,-88.939,-88.939,-88.939,-91.939,-91.939,-88.939,-92.939,-94.939,-93.939,-90.939,-87.939,-92.939,-94.939,-93.939,-96.939,-97.939,-96.939,-95.939,-94.939,-87.939,-83.939,-80.939,-73.939,-66.939,-61.939003,-64.939,-71.939,-70.939,-69.939,-70. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Unfortunately github truncated my gist. The full file is here: https://www.dropbox.com/s/ldjnyloxq2vpi2s/pytorch_memory_leak.py
I know it's large but it spins up and executes quickly :)