Skip to content

Instantly share code, notes, and snippets.

@Jc2k
Last active March 19, 2020 17:27
Show Gist options
  • Save Jc2k/c89e30fdc112b2329b12609bff8a9e07 to your computer and use it in GitHub Desktop.
Save Jc2k/c89e30fdc112b2329b12609bff8a9e07 to your computer and use it in GitHub Desktop.
homekit_controller fixes
#
# Copyright 2019 aiohomekit team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
from aiohomekit.crypto.chacha20poly1305 import (
chacha20_aead_decrypt,
chacha20_aead_encrypt,
)
from aiohomekit.exceptions import (
AccessoryDisconnectedError,
AccessoryNotFoundError,
AuthenticationError,
ConnectionError,
HomeKitException,
TimeoutError,
)
from aiohomekit.http import HttpContentTypes
from aiohomekit.http.response import HttpResponse
from aiohomekit.protocol import get_session_keys
from aiohomekit.protocol.tlv import TLV
from aiohomekit.zeroconf import async_find_device_ip_and_port
# Module-level logger for the connection/protocol classes below.
logger = logging.getLogger(__name__)
class InsecureHomeKitProtocol(asyncio.Protocol):
    """asyncio protocol that frames plain-text HTTP traffic for a HomeKitConnection.

    Requests are written straight to the transport. Replies are parsed with
    ``HttpResponse`` and handed back in the order they arrive, so several
    requests can be in flight at once without locking around each
    request/reply cycle. ``EVENT`` messages are forwarded to
    ``connection.event_received``.
    """

    def __init__(self, connection):
        self.connection = connection
        self.host = ":".join((connection.host, str(connection.port)))
        # One future per in-flight request, oldest first.
        self.result_cbs = []
        self.current_response = HttpResponse()
        # Set by connection_made(); None until then.
        self.transport = None

    def connection_made(self, transport):
        super().connection_made(transport)
        self.transport = transport

    def connection_lost(self, exception):
        # No reply can arrive on a dead transport, so fail every pending
        # request now instead of leaving callers waiting for the timeout.
        self.close()
        self.connection._connection_lost(exception)

    async def send_bytes(self, payload):
        """Write ``payload`` and wait (up to 30s) for the matching HTTP response.

        :raises AccessoryDisconnectedError: if the transport is not open, the
            connection is lost before a reply arrives, or no reply arrives in
            time (in which case the transport is closed).
        """
        if self.transport is None or self.transport.is_closing():
            # FIXME: It would be nice to try and wait for the reconnect in future.
            # In that case we need to make sure we do it at a layer above send_bytes otherwise
            # we might encrypt payloads with the last sessions keys then wait for a new connection
            # to send them - and on that connection the keys would be different.
            # Also need to make sure that the new connection has chance to pair-verify before
            # queued writes can happen.
            raise AccessoryDisconnectedError("Transport is closed")

        self.transport.write(payload)

        # Replies come back in request order, so queue a future and let
        # data_received() resolve the oldest one when a response completes.
        result = asyncio.get_event_loop().create_future()
        self.result_cbs.append(result)

        try:
            return await asyncio.wait_for(result, 30)
        except asyncio.TimeoutError:
            # wait_for() has cancelled ``result``. It stays queued so that a
            # late reply is still matched against the right request.
            if self.transport.can_write_eof():
                self.transport.write_eof()
            self.transport.close()
            raise AccessoryDisconnectedError("Timeout while waiting for response")

    def data_received(self, data):
        """Feed ``data`` to the HTTP parser and dispatch every completed message."""
        while data:
            # parse() returns the bytes left over once a message is complete
            # (the start of the next message), or nothing if it needs more input.
            data = self.current_response.parse(data)
            if self.current_response.is_read_completely():
                response = self.current_response
                self.current_response = HttpResponse()
                self._dispatch(response)

    def _dispatch(self, response):
        """Route a completed message: HTTP replies to waiters, EVENTs to the connection."""
        http_name = response.get_http_name().lower()
        if http_name == "http":
            if not self.result_cbs:
                logger.warning("Dropping unsolicited response from %s", self.host)
                return
            waiter = self.result_cbs.pop(0)
            # The waiter may already be finished, e.g. cancelled by a timeout.
            if not waiter.done():
                waiter.set_result(response)
        elif http_name == "event":
            self.connection.event_received(response)
        else:
            raise RuntimeError("Unknown http type")

    def eof_received(self):
        self.close()
        # Returning False lets the transport close itself.
        return False

    def close(self):
        """Fail every still-pending request future with AccessoryDisconnectedError."""
        # If the connection is closed then any pending callbacks will never
        # fire, so set them to an error state. Futures that already finished
        # (e.g. cancelled by a timeout) are skipped: setting them would raise.
        while self.result_cbs:
            result = self.result_cbs.pop(0)
            if not result.done():
                result.set_exception(AccessoryDisconnectedError("Connection closed"))
class SecureHomeKitProtocol(InsecureHomeKitProtocol):
    """Encrypted HAP session layered on top of InsecureHomeKitProtocol.

    Outgoing payloads are split into frames of at most 1024 plaintext bytes
    and encrypted with ``c2a_key``. Incoming frames are decrypted with
    ``a2c_key`` before being handed to the HTTP unframing of the base class.
    Each direction keeps its own frame counter.
    """

    # Largest plaintext carried by a single frame.
    MAX_FRAME_PLAINTEXT = 1024

    def __init__(self, connection, a2c_key, c2a_key):
        super().__init__(connection)
        self._incoming_buffer = bytearray()
        self.c2a_counter = 0
        self.a2c_counter = 0
        self.a2c_key = a2c_key
        self.c2a_key = c2a_key

    async def send_bytes(self, payload):
        """Encrypt ``payload`` into frames and send them as one write."""
        frames = []
        step = self.MAX_FRAME_PLAINTEXT
        for offset in range(0, len(payload), step):
            current = payload[offset : offset + step]
            len_bytes = len(current).to_bytes(2, byteorder="little")
            cnt_bytes = self.c2a_counter.to_bytes(8, byteorder="little")
            self.c2a_counter += 1
            data = chacha20_aead_encrypt(
                len_bytes, self.c2a_key, cnt_bytes, bytes([0, 0, 0, 0]), current,
            )
            frames.append(len_bytes)
            frames.append(data)
        # Joining once avoids the quadratic cost of repeated bytes concatenation.
        return await super().send_bytes(b"".join(frames))

    def data_received(self, data):
        """
        Called by asyncio when data is received from a TCP socket.

        This handles decryption of the 1024-byte blocks and hands them over to
        the underlying InsecureHomeKitProtocol, which handles HTTP unframing.

        Each frame is a 2-byte little-endian length, that many bytes of
        ciphertext, then a 16-byte tag. A partial frame stays in
        ``_incoming_buffer`` until the rest arrives. The blocks are expected
        to be in order, because decryption uses a per-frame counter
        (``a2c_counter``). There is no protocol-level support for interleaving
        of HTTP messages.
        """
        self._incoming_buffer += data

        while len(self._incoming_buffer) >= 2:
            block_length_bytes = self._incoming_buffer[:2]
            block_length = int.from_bytes(block_length_bytes, "little")
            frame_length = 2 + block_length + 16
            if len(self._incoming_buffer) < frame_length:
                # Not enough data yet
                return

            block = self._incoming_buffer[2 : 2 + block_length]
            tag = self._incoming_buffer[2 + block_length : frame_length]
            del self._incoming_buffer[:frame_length]

            decrypted = chacha20_aead_decrypt(
                block_length_bytes,
                self.a2c_key,
                self.a2c_counter.to_bytes(8, byteorder="little"),
                bytes([0, 0, 0, 0]),
                block + tag,
            )

            if decrypted is False:
                # Close explicitly rather than relying on how a particular
                # transport reacts to an exception escaping data_received().
                # Once a frame fails, the counters can no longer be trusted.
                self.transport.close()
                raise RuntimeError("Could not decrypt block")

            self.a2c_counter += 1

            super().data_received(decrypted)
class HomeKitConnection:
    """Plain-text HTTP connection to a HomeKit accessory with automatic reconnection.

    A background "connector" task (see start_connector) opens the TCP
    connection and retries with a growing delay. Requests go through the
    current protocol's ``send_bytes``.

    :param owner: optional object notified through ``connection_made(secure)``
        and ``event_received(parsed)``; may be ``None``.
    :param host: accessory host name or IP address.
    :param port: accessory TCP port.
    """

    def __init__(self, owner, host, port):
        self.owner = owner
        self.host = host
        self.port = port
        # ``closing`` is set by close(). ``closed`` becomes True when the
        # transport reports the loss while closing.
        self.closing = False
        self.closed = False
        # Starting value of the reconnect back-off, in seconds.
        self._retry_interval = 0.5
        self.transport = None
        self.protocol = None
        self._connector = None
        self.is_secure = False

    @property
    def is_connected(self):
        """True when a transport and protocol are attached and not closed."""
        return bool(self.transport and self.protocol and not self.closed)

    def start_connector(self):
        """Start the background reconnect task unless one already exists."""
        if self._connector:
            return

        def done_callback(result):
            self._connector = None
            try:
                result.result()
            except asyncio.CancelledError:
                pass
            except Exception:
                logger.exception("Unhandled error from connecter.")

        self._connector = asyncio.ensure_future(self._reconnect())
        self._connector.add_done_callback(done_callback)

    async def ensure_connection(self):
        """Wait for the connector task, starting it if not already connected.

        Errors that end the connector task (e.g. AuthenticationError) are
        re-raised here.
        """
        if self.is_connected:
            return
        # A previous close() leaves ``closed`` set. Clear both flags so that a
        # successful reconnect is reported by ``is_connected`` again.
        self.closing = False
        self.closed = False
        self.start_connector()
        await asyncio.shield(self._connector)

    async def stop_connector(self):
        """Cancel the connector task and wait for it to finish."""
        connector = self._connector
        if not connector:
            return
        connector.cancel()
        # Unlike ``await connector``, asyncio.wait() does not re-raise the
        # task's CancelledError (or other exception) into the caller.
        await asyncio.wait([connector])
        if self._connector is connector:
            self._connector = None

    async def get(self, target):
        """
        Sends a HTTP GET request to the current transport and waits for the response.
        """
        return await self.request(method="GET", target=target)

    async def get_json(self, target):
        """GET ``target`` and decode the response body as UTF-8 JSON."""
        logger.debug("get_json req %s", target)
        response = await self.get(target)
        body = response.body.decode("utf-8")
        logger.debug("get_json resp %s", body)
        return json.loads(body)

    async def put(self, target, body, content_type=HttpContentTypes.JSON):
        """
        Sends a HTTP PUT request to the current transport and waits for the response.
        """
        return await self.request(
            method="PUT",
            target=target,
            headers=[("Content-Type", content_type), ("Content-Length", len(body))],
            body=body,
        )

    async def put_json(self, target, body):
        """PUT ``body`` as JSON and return the decoded JSON reply.

        Returns ``{}`` for a 204 No Content reply.

        :raises AccessoryDisconnectedError: if the reply is not UTF-8 or not
            valid JSON. The transport is closed first.
        """
        logger.debug("put_json req %s %s", target, body)
        response = await self.put(
            target,
            json.dumps(body).encode("utf-8"),
            content_type=HttpContentTypes.JSON,
        )

        if response.code == 204:
            logger.debug("put_json: resp code 204, no content")
            return {}

        try:
            decoded = response.body.decode("utf-8")
        except UnicodeDecodeError:
            self._abort_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving non-utf8 response"
            )

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            self._abort_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving malformed response from device"
            )

        logger.debug("put_json resp bosy: %s", parsed)

        return parsed

    async def post(self, target, body, content_type=HttpContentTypes.TLV):
        """
        Sends a HTTP POST request to the current transport and waits for the response.
        """
        return await self.request(
            method="POST",
            target=target,
            headers=[("Content-Type", content_type), ("Content-Length", len(body))],
            body=body,
        )

    async def post_json(self, target, body):
        """POST ``body`` as JSON and return the decoded JSON reply.

        Returns ``{}`` when the reply body is empty.

        :raises AccessoryDisconnectedError: if the reply is not UTF-8 or not
            valid JSON. The transport is closed first.
        """
        logger.debug("post_json req %s %s", target, body)
        # The body is JSON, so label it as such (it was previously sent as TLV).
        response = await self.post(
            target, json.dumps(body).encode("utf-8"), content_type=HttpContentTypes.JSON,
        )

        if response.code != 204:
            logger.debug("post_json: resp code NOT 204: %s", response.code)
            # FIXME: ...

        try:
            decoded = response.body.decode("utf-8")
        except UnicodeDecodeError:
            self._abort_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving non-utf8 response"
            )

        if not decoded:
            # FIXME: Verify this is correct
            logger.debug("post_json: Decoded body is empty")
            return {}

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            self._abort_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving malformed response from device"
            )

        logger.debug("post_json resp bosy: %s", parsed)

        return parsed

    async def post_tlv(self, target, body, expected=None):
        """POST ``body`` TLV-encoded and return the TLV-decoded reply."""
        logger.debug("post_tlv req %s %s", target, body)
        response = await self.post(
            target, TLV.encode_list(body), content_type=HttpContentTypes.TLV,
        )
        body = TLV.decode_bytes(response.body, expected=expected)
        logger.debug("post_tlv resp: %s", body)
        return body

    async def request(self, method, target, headers=None, body=None):
        """
        Sends a HTTP request to the current transport and waits for the response.

        A ``Host:`` header is always added.

        :param method: A HTTP method, like 'GET' or 'POST'
        :param target: A URI to call the method on
        :param headers: a list of (header, value) tuples (optional)
        :param body: The body of the request (optional)
        :returns: the HttpResponse for this request
        :raises AccessoryDisconnectedError: if there is no protocol attached,
            the protocol fails to deliver a reply, or the reply has a 4xx status
        """
        if not self.protocol:
            raise AccessoryDisconnectedError(
                "Connection lost before request could be sent"
            )

        # WARNING: It is vital that a Host: header is present or some devices
        # will reject the request.
        lines = [f"{method.upper()} {target} HTTP/1.1", f"Host: {self.host}"]
        for header, value in headers or ():
            lines.append(f"{header}: {value}")
        # Two empty entries produce the blank line that ends the header block.
        lines.extend(("", ""))

        # WARNING: We use \r\n explicitly. \n is not enough for some.
        request_bytes = "\r\n".join(lines).encode("utf-8")

        if body:
            request_bytes += body

        # WARNING: It is vital that each request is sent in one call
        # Some devices are sensitive to unecrypted HTTP requests made in
        # multiple packets.

        # https://github.com/jlusiardi/homekit_python/issues/12
        # https://github.com/jlusiardi/homekit_python/issues/16

        logger.debug("raw req: %r", request_bytes)

        resp = await self.protocol.send_bytes(request_bytes)

        if 400 <= resp.code <= 499:
            logger.debug(
                "Got HTTP error %s for %s against %s", resp.code, method, target
            )
            raise AccessoryDisconnectedError(
                f"Got HTTP error {resp.code} for {method} against {target}"
            )

        logger.debug("raw resp: %r", resp.body)

        return resp

    async def close(self):
        """
        Close the connection transport and stop reconnecting.
        """
        self.closing = True

        await self.stop_connector()

        if self.transport:
            self.transport.close()

        self.protocol = None
        self.transport = None
        self.is_secure = False

    def _abort_transport(self):
        """Close the current transport, if any, after a malformed reply."""
        if self.transport:
            self.transport.close()

    def _connection_lost(self, exception):
        """
        Called by the Protocol instance when its transport's connection is lost.

        Starts reconnecting unless close() is in progress.
        """
        logger.debug("Connection %r lost.", self)

        if self.closing:
            self.closed = True
        else:
            self.start_connector()

        self.transport = None
        self.protocol = None

    async def _connect_once(self):
        """Open one plain-text TCP connection (10s timeout) and notify the owner.

        :raises TimeoutError: (aiohomekit's) if the connect times out.
        :raises ConnectionError: (aiohomekit's) on socket errors.
        """
        loop = asyncio.get_event_loop()
        logger.debug("Attempting connection to %s:%s", self.host, self.port)

        try:
            self.transport, self.protocol = await asyncio.wait_for(
                loop.create_connection(
                    lambda: InsecureHomeKitProtocol(self), self.host, self.port
                ),
                timeout=10,
            )

        except asyncio.TimeoutError:
            raise TimeoutError("Timeout")

        except OSError as e:
            raise ConnectionError(str(e))

        if self.owner:
            await self.owner.connection_made(False)

    async def _reconnect(self):
        """Call _connect_once() until it succeeds or close() is requested.

        The delay between attempts starts at 1.5x ``_retry_interval`` and
        grows by 1.5x per failure, capped at 60 seconds. AuthenticationError
        is re-raised, and cancellation during a connect attempt ends the loop
        quietly.
        """
        # FIXME: How to integrate discovery here?
        # There is aiozeroconf but that doesn't work on Windows until python 3.9
        # In HASS, zeroconf is a service provided by HASS itself and want to be able to
        # leverage that instead.
        interval = self._retry_interval
        while not self.closing:
            # Delay to use if this attempt fails. It is computed up front so
            # the log line reports the delay that is actually used.
            interval = min(60, 1.5 * interval)
            try:
                return await self._connect_once()

            except AuthenticationError:
                # Authentication errors should bubble up because auto-reconnect is unlikely to help
                raise

            except asyncio.CancelledError:
                return

            except HomeKitException:
                logger.debug(
                    "Connecting to accessory failed. Retrying in %.1f seconds",
                    interval,
                )

            except Exception:
                logger.exception(
                    "Unexpected error whilst trying to connect to accessory. Will retry."
                )

            await asyncio.sleep(interval)

    def event_received(self, event):
        """Decode an EVENT message body as JSON and pass it to the owner.

        Empty bodies and bodies that are not valid JSON are ignored.
        """
        logger.debug("EVENT: %s", event)

        if not self.owner:
            return

        # FIXME: Should drop the connection if can't parse the event?

        decoded = event.body.decode("utf-8")
        if not decoded:
            return

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            return

        self.owner.event_received(parsed)

    def __repr__(self):
        return "HomeKitConnection(host=%r, port=%r)" % (self.host, self.port)
class SecureHomeKitConnection(HomeKitConnection):
    """HomeKitConnection that runs pair-verify and then switches to an encrypted protocol.

    :param owner: see HomeKitConnection.
    :param pairing_data: stored pairing record. ``AccessoryIP``,
        ``AccessoryPort`` and ``AccessoryPairingID`` are read here, and the
        whole record is passed to ``get_session_keys``.
    """

    def __init__(self, owner, pairing_data):
        super().__init__(
            owner, pairing_data["AccessoryIP"], pairing_data["AccessoryPort"],
        )
        self.pairing_data = pairing_data

    @property
    def is_connected(self):
        return super().is_connected and self.is_secure

    async def _connect_once(self):
        """Connect, run pair-verify, and install a SecureHomeKitProtocol.

        If anything fails after the TCP connection is opened, including
        cancellation, the plaintext transport is closed before the error
        propagates. This way a retry does not leave the previous connection
        open.
        """
        self.is_secure = False

        try:
            self.host, self.port = await async_find_device_ip_and_port(
                self.pairing_data["AccessoryPairingID"]
            )
        except AccessoryNotFoundError:
            # Keep using the last known host/port.
            pass

        await super()._connect_once()

        try:
            c2a_key, a2c_key = await self._pair_verify()
        except BaseException:
            if self.transport:
                self.transport.close()
            self.transport = None
            self.protocol = None
            raise

        # Secure session has been negotiated - switch protocol so all future messages are encrypted
        self.protocol = SecureHomeKitProtocol(self, a2c_key, c2a_key)
        self.transport.set_protocol(self.protocol)
        self.protocol.connection_made(self.transport)

        self.is_secure = True

        logger.debug("Secure connection to %s:%s established", self.host, self.port)

        if self.owner:
            await self.owner.connection_made(True)

    async def _pair_verify(self):
        """Drive the pair-verify state machine and return ``(c2a_key, a2c_key)``."""
        state_machine = get_session_keys(self.pairing_data)

        request, expected = state_machine.send(None)
        while True:
            try:
                response = await self.post_tlv(
                    "/pair-verify", body=request, expected=expected,
                )
                request, expected = state_machine.send(response)
            except StopIteration as result:
                # If the state machine raises a StopIteration then we have session keys
                c2a_key, a2c_key = result.value
                return c2a_key, a2c_key
#
# Copyright 2019 aiohomekit team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import Union
from aiohomekit.exceptions import HttpException
# Module-level logger for the HTTP response parser below.
logger = logging.getLogger(__name__)
class HttpResponse(object):
    """Incremental parser for one HTTP/1.1-style message (``HTTP`` or ``EVENT``).

    Feed raw bytes to parse() until is_read_completely() is True. Bodies may
    be framed by ``Content-Length`` or by chunked transfer coding. With
    neither, the message is complete as soon as the headers end.
    """

    STATE_PRE_STATUS = 0
    STATE_HEADERS = 1
    STATE_BODY = 2
    STATE_DONE = 3

    def __init__(self) -> None:
        self._state = HttpResponse.STATE_PRE_STATUS
        # Bytes received but not yet consumed by the parser.
        self._raw_response = bytearray()
        self._is_ready = False
        self._is_chunked = False
        self._had_empty_chunk = False
        # -1 means no Content-Length header was seen.
        self._content_length = -1
        self.version = None
        self.code = None
        self.reason = None
        self.headers = []
        self.body = bytearray()

    def parse(self, part: Union[bytearray, bytes]) -> bytearray:
        """Consume ``part`` and return the bytes left over after this message.

        The leftover bytes (the start of the next message) are returned only
        once the message is complete. Otherwise an empty bytearray is returned
        and the unconsumed bytes are kept for the next call.

        :raises HttpException: on a malformed status or header line.
        """
        self._raw_response += part

        # Only the status line, headers and chunk-size lines are CRLF-delimited.
        # A Content-Length body is sliced by length below, because it may
        # itself contain CRLF bytes (e.g. binary TLV payloads).
        while self._expects_line():
            pos = self._raw_response.find(b"\r\n")
            if pos == -1:
                break
            line = self._raw_response[:pos]
            self._raw_response = self._raw_response[pos + 2 :]

            if self._state == HttpResponse.STATE_PRE_STATUS:
                self._parse_status_line(line)
            elif self._state == HttpResponse.STATE_HEADERS:
                if line == b"":
                    # this is the empty line after the headers
                    self._state = HttpResponse.STATE_BODY
                else:
                    self._parse_header_line(line)
            elif not self._parse_chunk(line):
                # the remaining bytes are not sufficient; wait for another call.
                break

        if (
            self._state == HttpResponse.STATE_BODY
            and not self._is_chunked
            and self._content_length > 0
        ):
            remaining = self._content_length - len(self.body)
            self.body += self._raw_response[:remaining]
            self._raw_response = self._raw_response[remaining:]

        if self.is_read_completely():
            # Whatever is left in the buffer is part of the next message
            if len(self._raw_response) > 0:
                logger.debug(
                    "Bytes left in buffer after parsing packet: %r", self._raw_response
                )
            return self._raw_response

        return bytearray()

    def _expects_line(self) -> bool:
        """Whether the next input to consume is a CRLF-terminated line."""
        if self._state in (HttpResponse.STATE_PRE_STATUS, HttpResponse.STATE_HEADERS):
            return True
        return self._state == HttpResponse.STATE_BODY and self._is_chunked

    def _parse_status_line(self, line: bytearray) -> None:
        """Parse ``<version> <code> <reason>`` and move on to the headers."""
        parts = line.split(b" ", 2)
        if len(parts) != 3:
            raise HttpException("Malformed status line.")
        self.version = parts[0].decode()
        self.code = int(parts[1])
        self.reason = parts[2].decode()
        self._state = HttpResponse.STATE_HEADERS

    def _parse_header_line(self, line: bytearray) -> None:
        """Record one ``Name: value`` header; framing headers match case-insensitively."""
        raw_name, sep, raw_value = line.partition(b":")
        if not sep:
            raise HttpException("Malformed header line.")
        name = raw_name.decode()
        value = raw_value.decode().strip()
        lowered = name.lower()
        if lowered == "transfer-encoding":
            if value.lower() == "chunked":
                self._is_chunked = True
        elif lowered == "content-length":
            self._content_length = int(value)
        self.headers.append((name, value))

    def _parse_chunk(self, size_line: bytearray) -> bool:
        """Consume one chunk given its size line; return False if more data is needed."""
        # Chunk extensions (";name=value") after the hex size are ignored.
        length = int(size_line.split(b";", 1)[0], 16)
        if length + 2 > len(self._raw_response):
            # Put the size line back so the chunk is re-parsed on the next call.
            self._raw_response = size_line + b"\r\n" + self._raw_response
            return False
        if length == 0:
            self._had_empty_chunk = True
            self._state = HttpResponse.STATE_DONE
        else:
            self.body += self._raw_response[:length]
        # Skip the chunk data plus its trailing CRLF.
        self._raw_response = self._raw_response[length + 2 :]
        return True

    def read(self):
        """
        Returns the body of the response.

        :return: The body read so far (an empty bytearray if none yet)
        """
        return self.body

    def is_read_completely(self) -> bool:
        """True once the whole message (including any body) has been parsed."""
        if self._is_chunked:
            return self._had_empty_chunk

        if self._state < HttpResponse.STATE_BODY:
            return False

        if self._content_length != -1:
            return len(self.body) == self._content_length

        return True

    def get_http_name(self) -> Union[str, None]:
        """
        Returns the HTTP name (e.g. HTTP or EVENT).

        :return: The name or None if the status line was not yet read
        """
        if self.version is not None:
            return self.version.split("/")[0]
        return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment