Skip to content

Instantly share code, notes, and snippets.

@Jc2k
Last active March 24, 2020 17:12
Show Gist options
  • Save Jc2k/11c2a53ebb1de1810655bb8758dbb132 to your computer and use it in GitHub Desktop.
Save Jc2k/11c2a53ebb1de1810655bb8758dbb132 to your computer and use it in GitHub Desktop.
Extra instrumentation
#
# Copyright 2019 aiohomekit team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
from aiohomekit.crypto.chacha20poly1305 import (
chacha20_aead_decrypt,
chacha20_aead_encrypt,
)
from aiohomekit.exceptions import (
AccessoryDisconnectedError,
AccessoryNotFoundError,
AuthenticationError,
ConnectionError,
HomeKitException,
HttpErrorResponse,
TimeoutError,
)
from aiohomekit.http import HttpContentTypes
from aiohomekit.http.response import HttpResponse
from aiohomekit.protocol import get_session_keys
from aiohomekit.protocol.tlv import TLV
from aiohomekit.zeroconf import async_find_device_ip_and_port
# Module-wide logger, named after this module so callers can tune its level.
logger = logging.getLogger(name=__name__)
class InsecureHomeKitProtocol(asyncio.Protocol):
    """
    Plaintext HTTP protocol for talking to a HomeKit accessory.

    Each ``send_bytes`` call queues a future in ``result_cbs``. ``data_received``
    resolves the oldest queued future whenever a complete ``HTTP`` message has
    been parsed, so responses are matched to requests strictly in order.
    ``EVENT`` messages are forwarded to ``connection.event_received``.
    """

    def __init__(self, connection):
        self.connection = connection
        self.host = ":".join((connection.host, str(connection.port)))
        # Futures waiting for a response, oldest first.
        self.result_cbs = []
        self.current_response = HttpResponse()

    def connection_made(self, transport):
        super().connection_made(transport)
        self.transport = transport

    def connection_lost(self, exception):
        # No further responses can arrive once the transport is gone, so fail
        # outstanding requests now rather than letting each wait out the
        # 30 second timeout in send_bytes.
        self.close()
        self.connection._connection_lost(exception)

    async def send_bytes(self, payload):
        """
        Write ``payload`` to the transport and wait for the matching response.

        :param payload: the raw bytes of one complete HTTP request
        :returns: the parsed HttpResponse answering this request
        :raises AccessoryDisconnectedError: if the transport is already closing,
            the connection is closed while waiting, or no response arrives
            within 30 seconds (the transport is closed in that case).
        """
        if self.transport.is_closing():
            # FIXME: It would be nice to try and wait for the reconnect in future.
            # In that case we need to make sure we do it at a layer above send_bytes otherwise
            # we might encrypt payloads with the last sessions keys then wait for a new connection
            # to send them - and on that connection the keys would be different.
            # Also need to make sure that the new connection has chance to pair-verify before
            # queued writes can happen.
            raise AccessoryDisconnectedError("Transport is closed")

        logger.debug("INSECURE write: %s", payload)
        self.transport.write(payload)

        # We return a future so that our caller can block on a reply.
        # Many requests can be in flight; results are dispatched in order,
        # so no locking is needed around request/reply cycles.
        result = asyncio.Future()
        self.result_cbs.append(result)

        try:
            return await asyncio.wait_for(result, 30)
        except asyncio.TimeoutError:
            self.transport.write_eof()
            self.transport.close()
            raise AccessoryDisconnectedError("Timeout while waiting for response")

    def data_received(self, data):
        """Parse incoming bytes and dispatch every completed HTTP/EVENT message."""
        logger.debug("INSECURE receive: %s", data)

        while data:
            data = self.current_response.parse(data)

            if self.current_response.is_read_completely():
                http_name = self.current_response.get_http_name().lower()
                if http_name == "http":
                    if not self.result_cbs:
                        logger.warning(
                            "%s: Dropping unsolicited HTTP response", self.host
                        )
                    else:
                        next_callback = self.result_cbs.pop(0)
                        # The waiter may already be cancelled (its task was
                        # cancelled or wait_for gave up). The response still
                        # belongs to it, so it is consumed here to keep later
                        # responses paired with the right requests.
                        if not next_callback.done():
                            next_callback.set_result(self.current_response)
                elif http_name == "event":
                    self.connection.event_received(self.current_response)
                else:
                    raise RuntimeError("Unknown http type")

                self.current_response = HttpResponse()

    def eof_received(self):
        self.close()
        # Returning False lets asyncio close the transport.
        return False

    def close(self):
        """
        Fail every pending request with AccessoryDisconnectedError.

        If the connection is closed, pending callbacks would never fire, so
        they are put into an error state instead. Safe to call repeatedly.
        """
        while self.result_cbs:
            result = self.result_cbs.pop(0)
            # Cancelled futures are already done; set_exception on them would
            # raise InvalidStateError.
            if not result.done():
                result.set_exception(AccessoryDisconnectedError("Connection closed"))
class SecureHomeKitProtocol(InsecureHomeKitProtocol):
    """
    Encrypted session layered over InsecureHomeKitProtocol.

    Outgoing payloads are cut into frames of at most 1024 bytes. Each frame is
    sent as a 2-byte little-endian length followed by the output of
    chacha20_aead_encrypt. Incoming frames have the same layout, with the
    sealed data being ``length + 16`` bytes (ciphertext plus tag). They are
    decrypted in order and the plaintext is handed to the parent class for
    HTTP unframing.
    """

    def __init__(self, connection, a2c_key, c2a_key):
        super().__init__(connection)
        # Received bytes that don't yet form a complete frame.
        self._incoming_buffer = bytearray()
        # Per-direction frame counters, used as the nonce for each frame.
        self.c2a_counter = 0
        self.a2c_counter = 0
        self.a2c_key = a2c_key
        self.c2a_key = c2a_key

    async def send_bytes(self, payload):
        """Encrypt ``payload`` frame by frame and send it via the parent protocol."""
        logger.debug("SECURE send_bytes %s", payload)

        frames = []
        for start in range(0, len(payload), 1024):
            chunk = payload[start : start + 1024]
            length_prefix = len(chunk).to_bytes(2, byteorder="little")
            nonce = self.c2a_counter.to_bytes(8, byteorder="little")
            self.c2a_counter += 1
            sealed = chacha20_aead_encrypt(
                length_prefix, self.c2a_key, nonce, bytes([0, 0, 0, 0]), chunk,
            )
            frames.append(length_prefix + sealed)

        return await super().send_bytes(b"".join(frames))

    def data_received(self, data):
        """
        Buffer incoming bytes, then decrypt and forward every complete frame.

        Frames must arrive in order; the protocol has no support for
        interleaving HTTP messages.
        """
        logger.debug("SECURE receive: %s", data)
        self._incoming_buffer += data

        while len(self._incoming_buffer) >= 2:
            length_prefix = self._incoming_buffer[:2]
            frame_length = int.from_bytes(length_prefix, "little")
            # 2 length bytes + ciphertext + 16 byte tag.
            frame_end = frame_length + 18
            if len(self._incoming_buffer) < frame_end:
                # Wait for the rest of this frame.
                break

            sealed = self._incoming_buffer[2:frame_end]
            del self._incoming_buffer[:frame_end]

            plaintext = chacha20_aead_decrypt(
                length_prefix,
                self.a2c_key,
                self.a2c_counter.to_bytes(8, byteorder="little"),
                bytes([0, 0, 0, 0]),
                sealed,
            )
            if plaintext is False:
                # NOTE(review): asyncio's selector-based transports treat an
                # exception escaping data_received as fatal and close the
                # connection; confirm for any other event loop in use.
                raise RuntimeError("Could not decrypt block")

            self.a2c_counter += 1
            super().data_received(plaintext)
class HomeKitConnection:
    """
    An HTTP connection to a HomeKit accessory that reconnects automatically.

    A background connector task (see ``start_connector``) retries the TCP
    connection with exponential back-off, capped at 60 seconds. It stops when a
    connection succeeds, when the connection is closing, or when an
    AuthenticationError is raised. Requests go through the current protocol and
    are answered in order. ``EVENT`` messages are decoded as JSON and passed to
    ``owner.event_received``.
    """

    def __init__(self, owner, host, port):
        self.owner = owner
        self.host = host
        self.port = port
        # Set by close(); suppresses automatic reconnection.
        self.closing = False
        # Set when the transport is lost while closing; cleared on reconnect.
        self.closed = False
        # Starting value for the reconnect back-off, in seconds.
        self._retry_interval = 0.5
        self.transport = None
        self.protocol = None
        # The running _reconnect() task, if any.
        self._connector = None
        self.is_secure = False

    @property
    def is_connected(self):
        """True when a transport and protocol are attached and the connection is not closed."""
        return bool(self.transport and self.protocol and not self.closed)

    def start_connector(self):
        """Start the background reconnect task unless one is already running."""
        if self._connector:
            return

        def done_callback(result):
            self._connector = None
            try:
                result.result()
            except asyncio.CancelledError:
                pass
            except Exception:
                logger.exception("Unhandled error from connecter.")

        self._connector = asyncio.ensure_future(self._reconnect())
        self._connector.add_done_callback(done_callback)

    async def ensure_connection(self):
        """
        Wait for the connector task to finish if not already connected.

        The connector is shielded, so cancelling the caller does not cancel the
        shared reconnect task. Exceptions raised by the connector (for example
        AuthenticationError) propagate to the caller.
        """
        if self.is_connected:
            return
        self.closing = False
        self.start_connector()
        await asyncio.shield(self._connector)

    async def stop_connector(self):
        """Cancel the background connector task (if any) and wait for it to finish."""
        connector = self._connector
        if not connector:
            return
        connector.cancel()
        # Awaiting a cancelled task re-raises CancelledError, which would abort
        # close(). asyncio.wait only waits for completion. Any failure has
        # already been logged by the connector's done callback.
        await asyncio.wait([connector])
        if self._connector is connector:
            self._connector = None

    async def get(self, target):
        """
        Send an HTTP GET request for ``target`` and return the response.
        """
        return await self.request(method="GET", target=target)

    async def get_json(self, target):
        """GET ``target`` and return its body decoded as UTF-8 JSON."""
        response = await self.get(target)
        body = response.body.decode("utf-8")
        return json.loads(body)

    async def put(self, target, body, content_type=HttpContentTypes.JSON):
        """
        Send an HTTP PUT request with ``body`` to ``target`` and return the response.
        """
        return await self.request(
            method="PUT",
            target=target,
            headers=[("Content-Type", content_type), ("Content-Length", len(body))],
            body=body,
        )

    async def put_json(self, target, body):
        """
        PUT ``body`` as JSON and return the decoded JSON reply.

        Returns ``{}`` for a 204 response. Raises AccessoryDisconnectedError and
        closes the transport if the reply is not UTF-8 or not valid JSON.
        """
        response = await self.put(
            target,
            json.dumps(body).encode("utf-8"),
            content_type=HttpContentTypes.JSON,
        )

        if response.code == 204:
            return {}

        try:
            decoded = response.body.decode("utf-8")
        except UnicodeDecodeError:
            self._close_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving non-utf8 response"
            )

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            self._close_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving malformed response from device"
            )

        return parsed

    async def post(self, target, body, content_type=HttpContentTypes.TLV):
        """
        Send an HTTP POST request with ``body`` to ``target`` and return the response.
        """
        return await self.request(
            method="POST",
            target=target,
            headers=[("Content-Type", content_type), ("Content-Length", len(body))],
            body=body,
        )

    async def post_json(self, target, body):
        """
        POST ``body`` as JSON and return the decoded JSON reply.

        Returns ``{}`` for an empty body. Raises AccessoryDisconnectedError and
        closes the transport if the reply is not UTF-8 or not valid JSON
        (same handling as put_json).
        """
        # The body is JSON, so label it as JSON rather than TLV.
        response = await self.post(
            target,
            json.dumps(body).encode("utf-8"),
            content_type=HttpContentTypes.JSON,
        )

        try:
            decoded = response.body.decode("utf-8")
        except UnicodeDecodeError:
            self._close_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving non-utf8 response"
            )

        if not decoded:
            # FIXME: Verify this is correct
            return {}

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            self._close_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving malformed response from device"
            )

        return parsed

    async def post_tlv(self, target, body, expected=None):
        """
        POST ``body`` as TLV and return the decoded TLV reply.

        A 4xx reply still has its body decoded, but the transport is closed first.
        """
        try:
            response = await self.post(
                target, TLV.encode_list(body), content_type=HttpContentTypes.TLV,
            )
        except HttpErrorResponse as e:
            self._close_transport()
            response = e.response
        return TLV.decode_bytes(response.body, expected=expected)

    async def request(self, method, target, headers=None, body=None):
        """
        Send an HTTP request over the current protocol and return the response.

        The Host header is set automatically.

        :param method: A HTTP method, like 'GET' or 'POST'
        :param target: A URI to call the method on
        :param headers: a list of (header, value) tuples (optional)
        :param body: The body of the request (optional)
        :raises AccessoryDisconnectedError: if there is no current protocol
        :raises HttpErrorResponse: for any 4xx response
        """
        if not self.protocol:
            raise AccessoryDisconnectedError(
                "Connection lost before request could be sent"
            )

        buffer = ["{method} {target} HTTP/1.1".format(method=method.upper(), target=target)]

        # WARNING: It is vital that a Host: header is present or some devices
        # will reject the request.
        buffer.append("Host: {host}".format(host=self.host))

        for header, value in headers or ():
            buffer.append("{header}: {value}".format(header=header, value=value))

        # Two empty entries yield the blank line that ends the header section.
        buffer.append("")
        buffer.append("")

        # WARNING: We use \r\n explicitly. \n is not enough for some.
        request_bytes = "\r\n".join(buffer).encode("utf-8")

        if body:
            request_bytes += body

        # WARNING: It is vital that each request is sent in one call
        # Some devices are sensitive to unencrypted HTTP requests made in
        # multiple packets.
        # https://github.com/jlusiardi/homekit_python/issues/12
        # https://github.com/jlusiardi/homekit_python/issues/16

        logger.debug("%s: raw request: %r", self.host, request_bytes)

        resp = await self.protocol.send_bytes(request_bytes)

        if 400 <= resp.code <= 499:
            logger.debug(
                "Got HTTP error %s for %s against %s", resp.code, method, target
            )
            raise HttpErrorResponse(
                f"Got HTTP error {resp.code} for {method} against {target}",
                response=resp,
            )

        logger.debug("%s: raw response: %r", self.host, resp.body)

        return resp

    async def close(self):
        """
        Close the connection transport and stop reconnecting.
        """
        self.closing = True

        await self.stop_connector()

        if self.transport:
            self.transport.close()

        self.protocol = None
        self.transport = None
        self.is_secure = False

    def _close_transport(self):
        """Close the current transport, if there is one."""
        if self.transport:
            self.transport.close()

    def _connection_lost(self, exception):
        """
        Called by a Protocol instance when its connection is lost.

        Starts reconnecting unless close() was requested, in which case the
        connection is marked as closed.
        """
        logger.debug("Connection %r lost.", self)

        if self.closing:
            self.closed = True
        else:
            self.start_connector()

        self.transport = None
        self.protocol = None

    async def _connect_once(self):
        """
        Open a plaintext TCP connection to ``host``:``port`` and notify the owner.

        :raises TimeoutError: (aiohomekit's) if connecting takes over 10 seconds
        :raises ConnectionError: (aiohomekit's) if the socket connection fails
        """
        loop = asyncio.get_event_loop()

        logger.debug("Attempting connection to %s:%s", self.host, self.port)

        try:
            self.transport, self.protocol = await asyncio.wait_for(
                loop.create_connection(
                    lambda: InsecureHomeKitProtocol(self), self.host, self.port
                ),
                timeout=10,
            )
        except asyncio.TimeoutError:
            raise TimeoutError("Timeout")
        except OSError as e:
            raise ConnectionError(str(e))

        # A previous close() leaves ``closed`` set. Clear it now that there is
        # a fresh connection, otherwise is_connected would stay False forever.
        self.closed = False

        if self.owner:
            await self.owner.connection_made(False)

    async def _reconnect(self):
        """
        Retry _connect_once with exponential back-off until it succeeds.

        Stops when the connection is closing. AuthenticationError is re-raised
        because retrying is unlikely to help.
        """
        # FIXME: How to integrate discovery here?
        # There is aiozeroconf but that doesn't work on Windows until python 3.9
        # In HASS, zeroconf is a service provided by HASS itself and want to be able to
        # leverage that instead.
        interval = self._retry_interval
        while not self.closing:
            # Delay before the next attempt if this one fails.
            interval = min(60, 1.5 * interval)
            try:
                return await self._connect_once()
            except AuthenticationError:
                # Authentication errors should bubble up because auto-reconnect is unlikely to help
                raise
            except asyncio.CancelledError:
                return
            except HomeKitException:
                logger.debug(
                    "Connecting to accessory failed. Retrying in %.1f seconds",
                    interval,
                )
            except Exception:
                logger.exception(
                    "Unexpected error whilst trying to connect to accessory. Will retry."
                )

            await asyncio.sleep(interval)

    def event_received(self, event):
        """Decode an EVENT message body as JSON and pass it to the owner."""
        if not self.owner:
            return

        # FIXME: Should drop the connection if can't parse the event?
        decoded = event.body.decode("utf-8")
        if not decoded:
            return

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            logger.debug("%s: Ignoring malformed event: %r", self.host, decoded)
            return

        self.owner.event_received(parsed)

    def __repr__(self):
        return "HomeKitConnection(host=%r, port=%r)" % (self.host, self.port)
class SecureHomeKitConnection(HomeKitConnection):
    """
    A HomeKitConnection that runs pair-verify after connecting, then switches
    the transport to an encrypted SecureHomeKitProtocol.

    ``pairing_data`` must provide ``AccessoryIP``, ``AccessoryPort`` and
    ``AccessoryPairingID``. It is also passed to ``get_session_keys``.
    """

    def __init__(self, owner, pairing_data):
        super().__init__(
            owner, pairing_data["AccessoryIP"], pairing_data["AccessoryPort"],
        )
        self.pairing_data = pairing_data

    @property
    def is_connected(self):
        """Connected only once the session has been secured by pair-verify."""
        return super().is_connected and self.is_secure

    async def _connect_once(self):
        """
        Connect, run pair-verify, and install the encrypted protocol.

        If pair-verify fails or is cancelled, the plaintext transport is closed
        before the exception propagates, so retries don't leak connections.
        """
        self.is_secure = False

        try:
            self.host, self.port = await async_find_device_ip_and_port(
                self.pairing_data["AccessoryPairingID"]
            )
        except AccessoryNotFoundError:
            # Not discovered: keep the previously known host and port.
            pass

        await super()._connect_once()

        try:
            c2a_key, a2c_key = await self._pair_verify()
        except BaseException:
            if self.transport:
                self.transport.close()
            raise

        # Secure session has been negotiated - switch protocol so all future messages are encrypted
        self.protocol = SecureHomeKitProtocol(self, a2c_key, c2a_key)
        self.transport.set_protocol(self.protocol)
        self.protocol.connection_made(self.transport)

        self.is_secure = True

        logger.debug("Secure connection to %s:%s established", self.host, self.port)

        if self.owner:
            await self.owner.connection_made(True)

    async def _pair_verify(self):
        """Drive the get_session_keys state machine over /pair-verify and return its result."""
        state_machine = get_session_keys(self.pairing_data)
        request, expected = state_machine.send(None)
        while True:
            response = await self.post_tlv(
                "/pair-verify", body=request, expected=expected,
            )
            try:
                request, expected = state_machine.send(response)
            except StopIteration as result:
                # The state machine finishing means the session keys are ready.
                return result.value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment