Add type hints for fods.

This commit is contained in:
ljnsn 2022-11-16 02:00:31 +01:00
parent e9e4b3b272
commit 538120ce82
1 changed files with 10 additions and 8 deletions

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any, Iterator, Union
from lxml import etree from lxml import etree
@ -15,11 +16,11 @@ TABLE_CELL_REPEATED_ATTRIB = "number-columns-repeated"
VALUE_TYPE_ATTRIB = "value-type" VALUE_TYPE_ATTRIB = "value-type"
def get_doc(file_or_path: Path): def get_doc(file_or_path: Path) -> etree._ElementTree:
return etree.parse(str(file_or_path)) return etree.parse(str(file_or_path))
def get_sheet(spreadsheet, sheet_id): def get_sheet(spreadsheet: etree._Element, sheet_id: Union[str, int]) -> etree._Element:
namespaces = spreadsheet.nsmap namespaces = spreadsheet.nsmap
if isinstance(sheet_id, str): if isinstance(sheet_id, str):
sheet = spreadsheet.find( sheet = spreadsheet.find(
@ -35,7 +36,10 @@ def get_sheet(spreadsheet, sheet_id):
return tables[sheet_id - 1] return tables[sheet_id - 1]
def get_rows(doc, sheet_id): def get_rows(
doc: etree._ElementTree,
sheet_id: Union[str, int],
) -> Iterator[etree._Element]:
if not isinstance(sheet_id, (str, int)): if not isinstance(sheet_id, (str, int)):
raise ValueError("Sheet id has to be either `str` or `int`") raise ValueError("Sheet id has to be either `str` or `int`")
root = doc.getroot() root = doc.getroot()
@ -44,18 +48,16 @@ def get_rows(doc, sheet_id):
SPREADSHEET_TAG, namespaces=namespaces SPREADSHEET_TAG, namespaces=namespaces
) )
sheet = get_sheet(spreadsheet, sheet_id) sheet = get_sheet(spreadsheet, sheet_id)
rows = sheet.findall(TABLE_ROW_TAG, namespaces=namespaces) return sheet.iterfind(TABLE_ROW_TAG, namespaces=namespaces)
for row in rows:
yield row
def is_float(cell): def is_float(cell: etree._Element) -> bool:
return ( return (
cell.attrib.get(f"{{{cell.nsmap[OFFICE_KEY]}}}{VALUE_TYPE_ATTRIB}") == "float" cell.attrib.get(f"{{{cell.nsmap[OFFICE_KEY]}}}{VALUE_TYPE_ATTRIB}") == "float"
) )
def get_value(cell, parsed=False): def get_value(cell: etree._Element, parsed: bool = False) -> tuple[Any, int]:
text = cell.find(TABLE_CELL_TEXT_TAG, namespaces=cell.nsmap) text = cell.find(TABLE_CELL_TEXT_TAG, namespaces=cell.nsmap)
if text is None: if text is None:
return None, 0 return None, 0