Last active
March 19, 2020 17:27
-
-
Save Jc2k/c89e30fdc112b2329b12609bff8a9e07 to your computer and use it in GitHub Desktop.
homekit_controller fixes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Copyright 2019 aiohomekit team | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
import asyncio | |
import json | |
import logging | |
from aiohomekit.crypto.chacha20poly1305 import ( | |
chacha20_aead_decrypt, | |
chacha20_aead_encrypt, | |
) | |
from aiohomekit.exceptions import ( | |
AccessoryDisconnectedError, | |
AccessoryNotFoundError, | |
AuthenticationError, | |
ConnectionError, | |
HomeKitException, | |
TimeoutError, | |
) | |
from aiohomekit.http import HttpContentTypes | |
from aiohomekit.http.response import HttpResponse | |
from aiohomekit.protocol import get_session_keys | |
from aiohomekit.protocol.tlv import TLV | |
from aiohomekit.zeroconf import async_find_device_ip_and_port | |
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
class InsecureHomeKitProtocol(asyncio.Protocol):
    """Plaintext HTTP framing for a connection to a HomeKit accessory.

    Requests are written with :meth:`send_bytes`, which returns the matching
    response. Responses are paired with requests strictly in send order.
    ``EVENT`` messages are forwarded to the owning connection's
    ``event_received``.
    """

    def __init__(self, connection):
        self.connection = connection
        self.host = ":".join((connection.host, str(connection.port)))
        # Futures waiting for a reply, oldest request first.
        self.result_cbs = []
        self.current_response = HttpResponse()

    def connection_made(self, transport):
        super().connection_made(transport)
        self.transport = transport

    def connection_lost(self, exception):
        # No reply can arrive on a dead transport. Fail every waiting request
        # now instead of leaving it to hit the 30 second timeout in send_bytes.
        self.close()
        self.connection._connection_lost(exception)

    async def send_bytes(self, payload):
        """Write *payload* in a single call and wait for the next response.

        Raises AccessoryDisconnectedError if the transport is already closing,
        if the connection closes before a reply arrives, or if no reply arrives
        within 30 seconds. In the timeout case the transport is also closed.
        """
        if self.transport.is_closing():
            # FIXME: It would be nice to try and wait for the reconnect in future.
            # In that case we need to make sure we do it at a layer above send_bytes otherwise
            # we might encrypt payloads with the last sessions keys then wait for a new connection
            # to send them - and on that connection the keys would be different.
            # Also need to make sure that the new connection has chance to pair-verify before
            # queued writes can happen.
            raise AccessoryDisconnectedError("Transport is closed")

        self.transport.write(payload)

        # We return a future so that our caller can block on a reply.
        # We can send many requests and dispatch the results in order.
        # Should mean we don't need locking around request/reply cycles.
        result = asyncio.Future()
        self.result_cbs.append(result)

        try:
            return await asyncio.wait_for(result, 30)
        except asyncio.TimeoutError:
            self.transport.write_eof()
            self.transport.close()
            raise AccessoryDisconnectedError("Timeout while waiting for response")

    def data_received(self, data):
        """Feed *data* to the HTTP parser and dispatch every completed message."""
        while data:
            data = self.current_response.parse(data)

            if self.current_response.is_read_completely():
                http_name = self.current_response.get_http_name().lower()
                if http_name == "http":
                    self._deliver_response(self.current_response)
                elif http_name == "event":
                    self.connection.event_received(self.current_response)
                else:
                    raise RuntimeError("Unknown http type")

                self.current_response = HttpResponse()

    def _deliver_response(self, response):
        """Resolve the oldest waiting request future with *response*."""
        if not self.result_cbs:
            logger.debug(
                "Dropping HTTP response with no pending request (code %s)",
                response.code,
            )
            return

        next_callback = self.result_cbs.pop(0)
        # The request may already have given up: wait_for() cancels its future
        # on timeout. The late reply still consumes that slot so later replies
        # stay paired with the right requests.
        if not next_callback.done():
            next_callback.set_result(response)

    def eof_received(self):
        self.close()
        # Returning False lets asyncio close the transport.
        return False

    def close(self):
        """Fail every request still waiting for a reply."""
        # If the connection is closed then any pending callbacks will never
        # fire, so set them to an error state.
        while self.result_cbs:
            result = self.result_cbs.pop(0)
            # Futures cancelled by a send_bytes timeout are already done, and
            # resolving them again would raise InvalidStateError.
            if not result.done():
                result.set_exception(AccessoryDisconnectedError("Connection closed"))
class SecureHomeKitProtocol(InsecureHomeKitProtocol):
    """Encrypted HAP session framing layered on InsecureHomeKitProtocol.

    Every frame on the wire is laid out as follows:

    - a 2-byte little-endian length;
    - that many bytes of ciphertext;
    - a 16-byte authentication tag.

    Outgoing payloads are cut into frames of at most 1024 plaintext bytes.
    The length bytes are passed as the first argument to the chacha20 AEAD
    helpers (presumably the associated data — confirm in
    aiohomekit.crypto.chacha20poly1305). A per-direction 8-byte little-endian
    frame counter is passed together with four zero bytes. Decrypted
    plaintext is handed to the parent class for HTTP parsing.
    """

    def __init__(self, connection, a2c_key, c2a_key):
        super().__init__(connection)
        # Keys for each direction (accessory->controller, controller->accessory).
        self.a2c_key = a2c_key
        self.c2a_key = c2a_key
        # Frame counters for each direction, bumped once per frame.
        self.a2c_counter = 0
        self.c2a_counter = 0
        # Received bytes that do not yet make up a whole frame.
        self._incoming_buffer = bytearray()

    async def send_bytes(self, payload):
        """Encrypt *payload* into frames and send them with a single write."""
        frames = []
        for offset in range(0, len(payload), 1024):
            plaintext = payload[offset:offset + 1024]
            length_prefix = len(plaintext).to_bytes(2, byteorder="little")
            counter = self.c2a_counter.to_bytes(8, byteorder="little")
            self.c2a_counter += 1
            sealed = chacha20_aead_encrypt(
                length_prefix, self.c2a_key, counter, bytes([0, 0, 0, 0]), plaintext,
            )
            frames.append(length_prefix + sealed)
        return await super().send_bytes(b"".join(frames))

    def data_received(self, data):
        """
        Called by asyncio when data is received from a TCP socket.

        Buffers *data*, then decrypts every complete frame in arrival order
        and passes the plaintext to InsecureHomeKitProtocol for HTTP
        unframing. Frames must arrive in order, because the protocol has no
        support for interleaving HTTP messages. A trailing partial frame stays
        buffered until more bytes arrive.
        """
        self._incoming_buffer += data

        while len(self._incoming_buffer) >= 2:
            length_prefix = self._incoming_buffer[:2]
            size = int.from_bytes(length_prefix, "little")
            frame_end = 2 + size + 16
            if len(self._incoming_buffer) < frame_end:
                # The rest of this frame has not arrived yet.
                return

            # Ciphertext followed by its 16-byte tag.
            sealed = self._incoming_buffer[2:frame_end]
            del self._incoming_buffer[:frame_end]

            plaintext = chacha20_aead_decrypt(
                length_prefix,
                self.a2c_key,
                self.a2c_counter.to_bytes(8, byteorder="little"),
                bytes([0, 0, 0, 0]),
                sealed,
            )
            if plaintext is False:
                # FIXME: Does raising here drop the connection or do we call close on transport ourselves
                raise RuntimeError("Could not decrypt block")

            self.a2c_counter += 1
            super().data_received(plaintext)
class HomeKitConnection:
    """HTTP connection to a HomeKit accessory with automatic reconnection.

    A background "connector" task (see :meth:`start_connector`) opens the TCP
    connection. It retries with back-off until it succeeds, hits an
    AuthenticationError, or :meth:`close` is called. Requests go through
    :meth:`request`, and the protocol object pairs replies with requests in
    send order.
    """

    def __init__(self, owner, host, port):
        # Optional object notified through ``connection_made(secure)`` and
        # ``event_received(parsed_json)``; may be None.
        self.owner = owner
        self.host = host
        self.port = port
        # Set by close(); suppresses automatic reconnection.
        self.closing = False
        # Set when the transport goes away after close() was requested.
        self.closed = False
        self._retry_interval = 0.5
        self.transport = None
        self.protocol = None
        # Task running _reconnect(), or None when no connect attempt is running.
        self._connector = None
        self.is_secure = False

    @property
    def is_connected(self):
        """Truthy while a transport and protocol exist and the connection is not marked closed."""
        return self.transport and self.protocol and not self.closed

    def start_connector(self):
        """Start the background (re)connect task unless one is already running."""
        if self._connector:
            return

        def done_callback(result):
            self._connector = None
            try:
                result.result()
            except asyncio.CancelledError:
                pass
            except Exception:
                logger.exception("Unhandled error from connector.")

        self._connector = asyncio.ensure_future(self._reconnect())
        self._connector.add_done_callback(done_callback)

    async def ensure_connection(self):
        """Return once connected, starting or joining the connector task if needed.

        Re-raises whatever the connector task raised (e.g. AuthenticationError).
        """
        if self.is_connected:
            return
        self.closing = False
        self.start_connector()
        # shield(): cancelling this caller must not cancel the shared connector task.
        await asyncio.shield(self._connector)

    async def stop_connector(self):
        """Cancel the connector task, if any, and wait for it to finish."""
        connector = self._connector
        if not connector:
            return
        connector.cancel()
        # Awaiting the task directly would re-raise CancelledError into our
        # caller (close()) whenever the task ends up cancelled. asyncio.wait()
        # only waits. Any other error is already logged by done_callback.
        await asyncio.wait([connector])
        if self._connector is connector:
            self._connector = None

    async def get(self, target):
        """
        Sends a HTTP GET request to the current transport and returns the
        response once it has arrived.
        """
        return await self.request(method="GET", target=target)

    async def get_json(self, target):
        """GET *target* and return its body parsed as UTF-8 JSON."""
        logger.debug("get_json req %s", target)
        response = await self.get(target)
        body = response.body.decode("utf-8")
        logger.debug("get_json resp %s", body)
        return json.loads(body)

    async def put(self, target, body, content_type=HttpContentTypes.JSON):
        """
        Sends a HTTP PUT request to the current transport and returns the
        response once it has arrived.
        """
        return await self.request(
            method="PUT",
            target=target,
            headers=[("Content-Type", content_type), ("Content-Length", len(body))],
            body=body,
        )

    async def put_json(self, target, body):
        """PUT *body* as JSON and return the parsed JSON reply ({} for a 204).

        Raises AccessoryDisconnectedError, after closing the transport, if the
        reply body is not UTF-8 or not JSON.
        """
        logger.debug("put_json req %s %s", target, body)
        response = await self.put(
            target,
            json.dumps(body).encode("utf-8"),
            content_type=HttpContentTypes.JSON,
        )

        if response.code == 204:
            logger.debug("put_json: resp code 204, no body")
            return {}

        parsed = self._loads_json(self._decode_body(response))
        logger.debug("put_json resp body: %s", parsed)
        return parsed

    async def post(self, target, body, content_type=HttpContentTypes.TLV):
        """
        Sends a HTTP POST request to the current transport and returns the
        response once it has arrived.
        """
        return await self.request(
            method="POST",
            target=target,
            headers=[("Content-Type", content_type), ("Content-Length", len(body))],
            body=body,
        )

    async def post_json(self, target, body):
        """POST *body* as JSON and return the parsed JSON reply ({} for an empty body).

        Raises AccessoryDisconnectedError, after closing the transport, if the
        reply body is not UTF-8 or not JSON.
        """
        logger.debug("post_json req %s %s", target, body)
        # The body is JSON, so label it as JSON rather than TLV.
        response = await self.post(
            target, json.dumps(body).encode("utf-8"), content_type=HttpContentTypes.JSON,
        )

        if response.code != 204:
            logger.debug("post_json: resp code NOT 204: %s", response.code)
            # FIXME: ...

        decoded = self._decode_body(response)
        if not decoded:
            # FIXME: Verify this is correct
            logger.debug("post_json: Decoded body is empty")
            return {}

        parsed = self._loads_json(decoded)
        logger.debug("post_json resp body: %s", parsed)
        return parsed

    async def post_tlv(self, target, body, expected=None):
        """POST *body* TLV-encoded and return the TLV-decoded reply."""
        logger.debug("post_tlv req %s %s", target, body)
        response = await self.post(
            target, TLV.encode_list(body), content_type=HttpContentTypes.TLV,
        )
        body = TLV.decode_bytes(response.body, expected=expected)
        logger.debug("post_tlv resp: %s", body)
        return body

    def _close_transport(self):
        """Close the current transport, if one is still attached."""
        if self.transport:
            self.transport.close()

    def _decode_body(self, response):
        """Return *response*'s body decoded as UTF-8.

        On invalid UTF-8 the transport is closed and AccessoryDisconnectedError raised.
        """
        try:
            return response.body.decode("utf-8")
        except UnicodeDecodeError as err:
            self._close_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving non-utf8 response"
            ) from err

    def _loads_json(self, decoded):
        """Parse *decoded* as JSON.

        On malformed JSON the transport is closed and AccessoryDisconnectedError raised.
        """
        try:
            return json.loads(decoded)
        except json.JSONDecodeError as err:
            self._close_transport()
            raise AccessoryDisconnectedError(
                "Session closed after receiving malformed response from device"
            ) from err

    async def request(self, method, target, headers=None, body=None):
        """
        Sends a HTTP request to the current transport and returns the response.

        The Host header is added automatically.

        :param method: A HTTP method, like 'GET' or 'POST'
        :param target: A URI to call the method on
        :param headers: a list of (header, value) tuples (optional)
        :param body: The body of the request (optional)
        :raises AccessoryDisconnectedError: if there is no protocol, or the
            accessory answers with a 4xx status
        """
        if not self.protocol:
            raise AccessoryDisconnectedError(
                "Connection lost before request could be sent"
            )

        buffer = []
        buffer.append(
            "{method} {target} HTTP/1.1".format(method=method.upper(), target=target)
        )

        # WARNING: It is vital that a Host: header is present or some devices
        # will reject the request.
        buffer.append("Host: {host}".format(host=self.host))

        if headers:
            for (header, value) in headers:
                buffer.append("{header}: {value}".format(header=header, value=value))

        buffer.append("")
        buffer.append("")

        # WARNING: We use \r\n explicitly. \n is not enough for some.
        request_bytes = "\r\n".join(buffer).encode("utf-8")

        if body:
            request_bytes += body

        # WARNING: It is vital that each request is sent in one call.
        # Some devices are sensitive to unencrypted HTTP requests made in
        # multiple packets.
        # https://github.com/jlusiardi/homekit_python/issues/12
        # https://github.com/jlusiardi/homekit_python/issues/16

        logger.debug("raw req: %r", request_bytes)

        resp = await self.protocol.send_bytes(request_bytes)

        if 400 <= resp.code <= 499:
            logger.debug(
                "Got HTTP error %s for %s against %s", resp.code, method, target
            )
            raise AccessoryDisconnectedError(
                f"Got HTTP error {resp.code} for {method} against {target}"
            )

        logger.debug("raw resp: %r", resp.body)

        return resp

    async def close(self):
        """
        Close the connection transport and stop any reconnect attempt.
        """
        self.closing = True

        await self.stop_connector()

        if self.transport:
            self.transport.close()

        self.protocol = None
        self.transport = None
        self.is_secure = False

    def _connection_lost(self, exception):
        """
        Called by the protocol from its connection_lost().

        Starts a reconnect unless close() has been requested. In that case the
        connection is marked closed instead.
        """
        logger.debug("Connection %r lost.", self)

        if self.closing:
            self.closed = True
        else:
            self.start_connector()

        self.transport = None
        self.protocol = None

    async def _connect_once(self):
        """Open one plaintext connection, with a 10 second timeout.

        Raises the aiohomekit TimeoutError / ConnectionError on failure.
        """
        loop = asyncio.get_event_loop()
        logger.debug("Attempting connection to %s:%s", self.host, self.port)

        try:
            self.transport, self.protocol = await asyncio.wait_for(
                loop.create_connection(
                    lambda: InsecureHomeKitProtocol(self), self.host, self.port
                ),
                timeout=10,
            )
        except asyncio.TimeoutError as e:
            raise TimeoutError("Timeout") from e
        except OSError as e:
            raise ConnectionError(str(e)) from e

        # A fresh connection supersedes an earlier close(). Otherwise `closed`
        # would stay True and is_connected would never report True again.
        self.closed = False

        if self.owner:
            await self.owner.connection_made(False)

    async def _reconnect(self):
        """Call _connect_once() until it succeeds, backing off up to 60 seconds."""
        # FIXME: How to integrate discovery here?
        # There is aiozeroconf but that doesn't work on Windows until python 3.9
        # In HASS, zeroconf is a service provided by HASS itself and want to be able to
        # leverage that instead.
        interval = 0.5
        while not self.closing:
            # Delay to wait if this attempt fails: 0.75s, then x1.5 per failure, capped at 60s.
            interval = min(60, 1.5 * interval)
            try:
                return await self._connect_once()

            except AuthenticationError:
                # Authentication errors should bubble up because auto-reconnect is unlikely to help
                raise

            except asyncio.CancelledError:
                return

            except HomeKitException:
                logger.debug(
                    "Connecting to accessory failed. Retrying in %.1f seconds",
                    interval,
                )

            except Exception:
                logger.exception(
                    "Unexpected error whilst trying to connect to accessory. Will retry."
                )

            await asyncio.sleep(interval)

    def event_received(self, event):
        """Parse an EVENT message's JSON body and forward it to the owner.

        Events whose body is empty, not UTF-8 or not JSON are dropped.
        """
        logger.debug("EVENT: %s", event)

        if not self.owner:
            return

        # FIXME: Should drop the connection if can't parse the event?
        try:
            decoded = event.body.decode("utf-8")
        except UnicodeDecodeError:
            logger.debug("Dropping EVENT with a non-utf8 body")
            return

        if not decoded:
            return

        try:
            parsed = json.loads(decoded)
        except json.JSONDecodeError:
            logger.debug("Dropping EVENT with a malformed JSON body")
            return

        self.owner.event_received(parsed)

    def __repr__(self):
        return "HomeKitConnection(host=%r, port=%r)" % (self.host, self.port)
class SecureHomeKitConnection(HomeKitConnection):
    """HomeKitConnection that upgrades to an encrypted session via pair-verify.

    ``pairing_data`` must provide at least "AccessoryIP", "AccessoryPort" and
    "AccessoryPairingID". The whole dict is passed to get_session_keys().
    """

    def __init__(self, owner, pairing_data):
        super().__init__(
            owner, pairing_data["AccessoryIP"], pairing_data["AccessoryPort"],
        )
        self.pairing_data = pairing_data

    @property
    def is_connected(self):
        # Only counts as connected once pair-verify has completed.
        return super().is_connected and self.is_secure

    async def _connect_once(self):
        """Connect, run pair-verify, then switch to the encrypted protocol.

        If pair-verify raises (including AuthenticationError or cancellation),
        the plaintext transport opened here is closed before the error is
        re-raised. Otherwise the next retry would overwrite ``self.transport``
        and leak the old socket.
        """
        self.is_secure = False

        try:
            self.host, self.port = await async_find_device_ip_and_port(
                self.pairing_data["AccessoryPairingID"]
            )
        except AccessoryNotFoundError:
            # Not found by discovery: keep using the current host/port.
            pass

        await super()._connect_once()

        try:
            c2a_key, a2c_key = await self._pair_verify()
        except BaseException:
            # BaseException so that cancellation also releases the socket.
            if self.transport:
                self.transport.close()
            raise

        # Secure session has been negotiated - switch protocol so all future messages are encrypted
        self.protocol = SecureHomeKitProtocol(self, a2c_key, c2a_key)
        self.transport.set_protocol(self.protocol)
        self.protocol.connection_made(self.transport)

        self.is_secure = True

        logger.debug("Secure connection to %s:%s established", self.host, self.port)

        if self.owner:
            await self.owner.connection_made(True)

    async def _pair_verify(self):
        """Run the pair-verify exchange over the plaintext connection.

        Drives the get_session_keys() generator. Each value it yields is a
        ``(request, expected)`` pair to POST to /pair-verify. Its return value,
        delivered via StopIteration, is ``(c2a_key, a2c_key)``.
        """
        state_machine = get_session_keys(self.pairing_data)
        request, expected = state_machine.send(None)
        while True:
            try:
                response = await self.post_tlv(
                    "/pair-verify", body=request, expected=expected,
                )
                request, expected = state_machine.send(response)
            except StopIteration as result:
                # If the state machine raises a StopIteration then we have session keys
                return result.value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Copyright 2019 aiohomekit team | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
import logging | |
from typing import Union | |
from aiohomekit.exceptions import HttpException | |
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
class HttpResponse(object):
    """Incremental parser for a single HTTP (or HAP ``EVENT``) response.

    Feed raw bytes to :meth:`parse` as they arrive. Once
    :meth:`is_read_completely` is True, the attributes ``version``, ``code``,
    ``reason``, ``headers`` and ``body`` are populated. ``parse`` then returns
    any bytes that belong to the following message.
    """

    STATE_PRE_STATUS = 0
    STATE_HEADERS = 1
    STATE_BODY = 2
    STATE_DONE = 3

    def __init__(self) -> None:
        self._state = HttpResponse.STATE_PRE_STATUS
        # Received bytes not yet consumed by the parser.
        self._raw_response = bytearray()
        self._is_ready = False
        self._is_chunked = False
        self._had_empty_chunk = False
        # -1 means no Content-Length header has been seen.
        self._content_length = -1
        self.version = None
        self.code = None
        self.reason = None
        self.headers = []
        self.body = bytearray()

    def parse(self, part: Union[bytearray, bytes]) -> bytearray:
        """Consume *part* and advance the parser.

        Only the status line and headers are split on CRLF. A Content-Length
        body is taken by byte count and a chunked body by chunk sizes, so a
        CRLF inside the body, or a following message, is never misread.

        :return: the unconsumed bytes (the start of the next message) once this
            response is complete, otherwise an empty bytearray.
        """
        self._raw_response += part

        self._parse_head()

        if self._state == HttpResponse.STATE_BODY:
            if self._is_chunked:
                self._parse_chunks()
            elif self._content_length > 0:
                remaining = self._content_length - len(self.body)
                self.body += self._raw_response[:remaining]
                self._raw_response = self._raw_response[remaining:]

        if self.is_read_completely():
            # Whatever is left in the buffer is part of the next request
            if len(self._raw_response) > 0:
                logger.debug(
                    "Bytes left in buffer after parsing packet: %r", self._raw_response
                )
            return self._raw_response

        return bytearray()

    def _parse_head(self) -> None:
        """Consume complete status/header lines until the blank line ending the headers."""
        while self._state in (HttpResponse.STATE_PRE_STATUS, HttpResponse.STATE_HEADERS):
            pos = self._raw_response.find(b"\r\n")
            if pos == -1:
                # Incomplete line; wait for more data.
                return

            line = self._raw_response[:pos]
            self._raw_response = self._raw_response[pos + 2 :]

            if self._state == HttpResponse.STATE_PRE_STATUS:
                self._parse_status_line(line)
            elif line == b"":
                # This is the empty line after the headers.
                self._state = HttpResponse.STATE_BODY
            else:
                self._parse_header_line(line)

    def _parse_status_line(self, line: bytearray) -> None:
        """Parse e.g. ``HTTP/1.1 200 OK`` or ``EVENT/1.0 200 OK``."""
        parts = line.split(b" ", 2)
        if len(parts) == 2:
            # Tolerate a status line with the reason phrase omitted.
            parts.append(b"")
        if len(parts) != 3:
            raise HttpException("Malformed status line.")
        self.version = parts[0].decode()
        self.code = int(parts[1])
        self.reason = parts[2].decode()
        self._state = HttpResponse.STATE_HEADERS

    def _parse_header_line(self, line: bytearray) -> None:
        """Record one ``Name: value`` header, noting framing-related headers."""
        raw_name, separator, raw_value = line.partition(b":")
        if not separator:
            raise HttpException("Malformed header line.")
        name = raw_name.decode()
        value = raw_value.decode().strip()

        # Header names are case-insensitive in HTTP.
        lowered = name.lower()
        if lowered == "transfer-encoding":
            if value.lower() == "chunked":
                self._is_chunked = True
        elif lowered == "content-length":
            self._content_length = int(value)

        self.headers.append((name, value))

    def _parse_chunks(self) -> None:
        """Consume complete chunks (chunked transfer coding) from the buffer."""
        while self._state == HttpResponse.STATE_BODY:
            pos = self._raw_response.find(b"\r\n")
            if pos == -1:
                return

            # The chunk size is hex. Ignore optional ";extension" parts.
            length = int(self._raw_response[:pos].split(b";", 1)[0], 16)
            data_start = pos + 2
            data_end = data_start + length
            if data_end + 2 > len(self._raw_response):
                # The chunk data and its trailing CRLF are not all here yet.
                return

            chunk = self._raw_response[data_start:data_end]
            self._raw_response = self._raw_response[data_end + 2 :]

            if length == 0:
                self._had_empty_chunk = True
                self._state = HttpResponse.STATE_DONE
            else:
                self.body += chunk

    def read(self):
        """
        Returns the body of the response.

        :return: the body read so far (an empty bytearray if none yet)
        """
        return self.body

    def is_read_completely(self) -> bool:
        """True once the whole response, including any body, has been parsed."""
        if self._is_chunked:
            return self._had_empty_chunk

        if self._state < HttpResponse.STATE_BODY:
            return False

        if self._content_length != -1:
            return len(self.body) == self._content_length

        # No Content-Length and not chunked: the response ends with the headers.
        return True

    def get_http_name(self) -> Union[str, None]:
        """
        Returns the HTTP name (e.g. HTTP or EVENT).

        :return: The name or None if the status line was not yet read
        """
        if self.version is not None:
            return self.version.split("/")[0]
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment