1import abc
2import os
3import struct
4import subprocess
5
6from google.auth import exceptions
7from google.oauth2.webauthn_types import GetRequest, GetResponse
8
9
class WebAuthnHandler(abc.ABC):
    """Abstract interface for objects that can carry out WebAuthn operations.

    Concrete subclasses must override both ``is_available`` and ``get``; the
    base implementations below only raise ``NotImplementedError``.
    """

    @abc.abstractmethod
    def is_available(self) -> bool:
        """Report whether this handler can currently be used.

        Returns:
            bool: whether the handler is available (subclass-defined).
        """
        message = "is_available method must be implemented"
        raise NotImplementedError(message)

    @abc.abstractmethod
    def get(self, get_request: GetRequest) -> GetResponse:
        """Perform a WebAuthn get (assertion) operation.

        Args:
            get_request: The assertion request to service.

        Returns:
            GetResponse: The result of the assertion.
        """
        message = "get method must be implemented"
        raise NotImplementedError(message)
20
21
class PluginHandler(WebAuthnHandler):
    """Offloads WebAuthn get requests to a pluggable command-line tool.

    Offloads WebAuthn get to a plugin which takes the form of a
    command-line tool. The command-line tool is configurable via the
    PluginHandler._ENV_VAR environment variable.

    The WebAuthn plugin should implement the following interface:

    Communication occurs over stdin/stdout, and messages are both sent and
    received in the form:

    [4 bytes - payload size (little-endian)][variable bytes - json payload]

    The payload size is an unsigned 32-bit integer giving the number of
    bytes in the UTF-8 encoded JSON payload (not the number of characters).
    """

    _ENV_VAR = "GOOGLE_AUTH_WEBAUTHN_PLUGIN"

    # struct format of the message length prefix: little-endian uint32.
    _LENGTH_FORMAT = "<I"
    # Number of bytes occupied by the length prefix (4).
    _LENGTH_SIZE = struct.calcsize(_LENGTH_FORMAT)

    def is_available(self) -> bool:
        """Return True if a plugin command is configured, else False."""
        try:
            self._find_plugin()
        except Exception:
            # Availability is a best-effort probe: any failure to locate the
            # plugin simply means this handler cannot be used.
            return False
        return True

    def get(self, get_request: GetRequest) -> GetResponse:
        """Perform a WebAuthn get (assertion) by delegating to the plugin.

        Args:
            get_request: The request to forward; serialized via ``to_json``.

        Returns:
            GetResponse: Parsed from the plugin's JSON reply via
            ``GetResponse.from_json``.

        Raises:
            google.auth.exceptions.InvalidResource: If the plugin environment
                variable is unset or empty.
            google.auth.exceptions.MalformedError: If the plugin's reply is
                not correctly length-framed.
            subprocess.CalledProcessError: If the plugin exits non-zero.
            OSError: If the plugin command cannot be executed.
        """
        request_json = get_request.to_json()
        cmd = self._find_plugin()
        response_json = self._call_plugin(cmd, request_json)
        return GetResponse.from_json(response_json)

    def _call_plugin(self, cmd: str, input_json: str) -> str:
        """Send one length-framed JSON message to ``cmd`` and read its reply.

        Args:
            cmd: Path or name of the plugin executable.
            input_json: JSON text to send on the plugin's stdin.

        Returns:
            str: The JSON payload of the plugin's reply, decoded as UTF-8.

        Raises:
            google.auth.exceptions.MalformedError: If the reply is shorter
                than the length prefix, or the prefix does not match the
                number of payload bytes that follow.
            subprocess.CalledProcessError: If the plugin exits non-zero.
        """
        # The prefix must count encoded bytes, not characters; otherwise any
        # non-ASCII content would desynchronize the framing.
        payload = input_json.encode("utf-8")
        request = struct.pack(self._LENGTH_FORMAT, len(payload)) + payload

        # check=True turns a non-zero plugin exit status into an exception.
        process_result = subprocess.run(
            [cmd], input=request, capture_output=True, check=True
        )
        output = process_result.stdout

        # Guard before unpacking so a truncated reply surfaces as a
        # MalformedError rather than a low-level struct.error.
        if len(output) < self._LENGTH_SIZE:
            raise exceptions.MalformedError(
                "Plugin response of {} bytes is too short for a length "
                "prefix".format(len(output))
            )
        (response_len,) = struct.unpack_from(self._LENGTH_FORMAT, output)
        response = output[self._LENGTH_SIZE :]
        if response_len != len(response):
            raise exceptions.MalformedError(
                "Plugin response length {} does not match data {}".format(
                    response_len, len(response)
                )
            )
        return response.decode("utf-8")

    def _find_plugin(self) -> str:
        """Return the plugin command configured in the environment.

        Raises:
            google.auth.exceptions.InvalidResource: If the environment
                variable is unset or empty.
        """
        plugin_cmd = os.environ.get(PluginHandler._ENV_VAR)
        # An empty value cannot name an executable, so treat it as unset.
        if not plugin_cmd:
            raise exceptions.InvalidResource(
                "{} env var is not set".format(PluginHandler._ENV_VAR)
            )
        return plugin_cmd