Source code for koapy.utils.store.SQLiteStoreLibrary

import hashlib

import pandas as pd
import pytz

from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import MetaData, Table

from .misc.VersionedItem import VersionedItem
from .sqlalchemy.Timestamp import Timestamp


class SQLiteStoreLibrary:
    """Read/write access to the versioned symbols of one library.

    Every operation is delegated to the ``library`` model object handed to the
    constructor; this class adds the pandas <-> SQL table conversion and the
    commit / rollback handling around mutating calls.
    """

    def __init__(self, store, library):
        """Bind to ``library``, reusing the engine and session owned by ``store``.

        Args:
            store: Object exposing ``_engine`` and ``_session`` attributes.
            library: Library model object that all queries are delegated to.
        """
        self._store = store
        self._library = library
        self._engine = self._store._engine
        self._session = self._store._session

    def list_symbols(self, all_symbols=False, snapshot=None):
        """Return the names of the symbols in the library or in a snapshot.

        Args:
            all_symbols: Forwarded as ``deleted=`` to ``library.get_symbols``;
                ignored when ``snapshot`` is given.
            snapshot: Optional snapshot name; when given, the symbols of that
                snapshot are listed instead.
        """
        if snapshot is not None:
            symbols = self._library.get_snapshot(snapshot).get_symbols()
        else:
            symbols = self._library.get_symbols(deleted=all_symbols)
        return [symbol.name for symbol in symbols]

    def has_symbol(self, symbol):
        """Return ``True`` unless looking ``symbol`` up raises ``NoResultFound``."""
        try:
            self._library.get_symbol(symbol)
        except NoResultFound:
            return False
        return True

    def list_versions(self, symbol=None, snapshot=None, latest_only=False):
        """Describe versions as a list of plain dictionaries.

        The filters are mutually exclusive and are checked in this order:
        ``symbol``, then ``snapshot``, then ``latest_only``; with none of them
        set, every version of the library is listed.

        Returns:
            A list of dicts with the keys ``symbol``, ``version``, ``deleted``,
            ``timestamp`` and ``snapshots`` (names of the snapshots that
            contain the version).
        """
        if symbol is not None:
            versions = self._library.get_symbol(symbol).get_versions()
        elif snapshot is not None:
            versions = self._library.get_snapshot(snapshot).get_versions()
        elif latest_only:
            versions = self._library.get_latest_versions()
        else:
            versions = self._library.get_versions()
        return [
            {
                "symbol": version.symbol.name,
                "version": version.version,
                "deleted": version.deleted,
                "timestamp": version.timestamp,
                "snapshots": [item.name for item in version.get_snapshots()],
            }
            for version in versions
        ]

    def _hash_data(self, data):
        """Return a SHA-256 hex digest identifying ``data`` (used as table name).

        The digest covers the per-row pandas hashes *and* the index / column
        labels.  The row hashes alone ignore the labels, so two frames holding
        the same values under different column names used to map to the same
        table; since ``write`` stores tables with ``if_exists="replace"``, the
        later write would overwrite the table an earlier version still
        references.
        """
        digest = hashlib.sha256(pd.util.hash_pandas_object(data).values.tobytes())

        columns = getattr(data, "columns", None)
        if columns is None:
            # Series / Index input: the only "column" label is its name.
            column_labels = [getattr(data, "name", None)]
        else:
            column_labels = list(columns)
        index = getattr(data, "index", None)
        index_labels = [] if index is None else list(index.names)

        schema = repr(
            (
                [str(label) for label in index_labels],
                [str(label) for label in column_labels],
            )
        )
        digest.update(schema.encode("utf-8"))
        return digest.hexdigest()

    def write(self, symbol, data, metadata=None, prune_previous_version=True, **kwargs):
        """Store ``data`` as a new version of ``symbol`` and return it re-read.

        Args:
            symbol: Symbol name; the symbol is created when missing.
            data: DataFrame to store, or ``None`` to record a version without
                a backing table (this is how ``delete`` writes its marker).
            metadata: Optional user metadata mapping.  Its ``"deleted"`` entry,
                if any, is forwarded as the ``deleted`` flag of the version.
            prune_previous_version: When true, prunable older versions are
                removed after the write.
            **kwargs: Forwarded to ``_prune_previous_versions``
                (e.g. ``keep_mins``).

        Returns:
            The result of ``self.read(symbol)`` after the write.

        Raises:
            Any exception raised while creating the version or writing the
            table is re-raised after the session has been rolled back.
        """
        symbol_name = symbol
        try:
            symbol = self._library.get_or_create_symbol(symbol_name)
            if data is None:
                symbol.create_new_version(
                    user_metadata=metadata,
                    deleted=metadata and metadata.get("deleted"),
                )
                self._session.commit()
            else:
                index_original_names = list(data.index.names)
                # Unnamed index levels still need a column name in the table.
                index_table_names = [
                    name if name is not None else "index_%d" % position
                    for position, name in enumerate(data.index.names)
                ]
                datetime_columns = [
                    (name, column_dtype)
                    for name, column_dtype in data.dtypes.items()
                    if pd.api.types.is_datetime64_any_dtype(column_dtype)
                ]
                # Only dtypes carrying a ``tz`` attribute get a recorded zone.
                column_timezones = {
                    name: str(column_dtype.tz)
                    for name, column_dtype in datetime_columns
                    if hasattr(column_dtype, "tz")
                }
                sql_dtypes = {
                    name: Timestamp(getattr(column_dtype, "tz", False))
                    for name, column_dtype in datetime_columns
                }
                # Everything ``read_as_dataframe`` needs to rebuild the frame.
                pandas_metadata = {
                    "read_sql_table": {
                        "index_col": index_table_names,
                        "parse_dates": [name for name, _ in datetime_columns],
                        "index_names": index_original_names,
                        "column_timezones": column_timezones,
                    },
                }
                table_name = self._hash_data(data)
                symbol.create_new_version(
                    table_name=table_name,
                    user_metadata=metadata,
                    pandas_metadata=pandas_metadata,
                    deleted=metadata and metadata.get("deleted"),
                )
                # NOTE(review): the version row is committed before the table
                # is written, so a failing ``to_sql`` is not undone by the
                # rollback below.  Order kept as in the original.
                self._session.commit()
                data.to_sql(
                    table_name,
                    self._engine,
                    index_label=index_table_names,
                    dtype=sql_dtypes,
                    if_exists="replace",
                )
        except BaseException:
            # Same reach as a bare ``except``: roll back, then propagate.
            self._session.rollback()
            raise
        if prune_previous_version:
            self._prune_previous_versions(symbol_name, **kwargs)
        return self.read(symbol_name)

    def _get_version(self, symbol, as_of=None):
        """Resolve the version of ``symbol`` selected by ``as_of``.

        ``as_of`` may be ``None`` (latest version), an ``int`` (version
        number) or a ``str`` (snapshot name).

        Raises:
            ValueError: If ``as_of`` is of any other type.
        """
        symbol = self._library.get_symbol(symbol)
        if as_of is None:
            return symbol.get_latest_version()
        if isinstance(as_of, int):
            return symbol.get_version_by_number(as_of)
        if isinstance(as_of, str):
            return self._library.get_snapshot(as_of).get_version_of_symbol(symbol)
        raise ValueError("Invalid as_of argument")

    def read_as_dataframe(
        self, symbol, as_of=None, time_column=None, start_time=None, end_time=None
    ):
        """Load a version of ``symbol`` as a DataFrame wrapped in a VersionedItem.

        ``time_column``, ``start_time`` and ``end_time`` are accepted for
        signature parity with ``read_as_cursor`` but are not used here: the
        whole table is always loaded.

        The item's data is ``None`` when the version has no backing table.
        """
        symbol_name = symbol
        version = self._get_version(symbol_name, as_of)
        data = None
        if version.table_name is not None:
            read_options = version.pandas_metadata["read_sql_table"]
            data = pd.read_sql_table(
                version.table_name,
                self._engine,
                index_col=read_options["index_col"],
                parse_dates=read_options["parse_dates"],
            )

            # Restore the original index names (unnamed levels come back as None).
            index_names = tuple(read_options["index_names"])
            if len(data.index.names) == 1:
                index_names = index_names[0]
            data.index.rename(index_names, inplace=True)  # pylint: disable=no-member

            column_timezones = read_options["column_timezones"]
            for column_name in read_options["parse_dates"]:
                timezone = column_timezones.get(column_name, Timestamp.local_timezone)
                if isinstance(timezone, str):
                    timezone = pytz.timezone(timezone)
                # Values are read back as naive, treated as UTC, then shifted
                # into the target zone.
                data[column_name] = (
                    data[column_name]
                    .dt.tz_localize(Timestamp.utc)
                    .dt.tz_convert(timezone)
                )
                if column_name not in column_timezones:
                    # No zone was recorded for this column: return it naive.
                    data[column_name] = data[column_name].dt.tz_localize(None)
        return VersionedItem(
            self._library.name,
            symbol_name,
            version.version,
            version.timestamp,
            data,
            version.user_metadata,
        )

    def read_as_cursor(
        self, symbol, as_of=None, time_column=None, start_time=None, end_time=None
    ):
        """Open a streaming result over a version's table.

        Args:
            symbol: Symbol name.
            as_of: Version selector, see ``_get_version``.
            time_column: Column to order by and to apply the time bounds to.
                When a bound is given without it, the table's first column
                is used.
            start_time: Optional inclusive lower bound.
            end_time: Optional inclusive upper bound.

        Bounds are parsed with ``pd.Timestamp``; naive values are localized to
        ``Timestamp.local_timezone`` and every bound is converted to UTC
        before being compared.

        Returns:
            A VersionedItem whose data is the executed result, or ``None`` when
            the version has no backing table.
        """
        symbol_name = symbol
        version = self._get_version(symbol_name, as_of)
        cursor = None
        if version.table_name is not None:
            records = Table(version.table_name, MetaData(), autoload_with=self._engine)
            statement = select(records)
            if time_column is not None:
                time_column = records.columns[  # pylint: disable=unsubscriptable-object
                    time_column
                ]
                statement = statement.order_by(time_column)
            for bound, is_lower_bound in ((start_time, True), (end_time, False)):
                if bound is None:
                    continue
                bound = pd.Timestamp(bound)
                if time_column is None:
                    # No explicit time column: fall back to the first column.
                    time_column = records.columns[  # pylint: disable=unsubscriptable-object
                        0
                    ]
                    statement = statement.order_by(time_column)
                if Timestamp.is_naive(bound):
                    bound = bound.tz_localize(Timestamp.local_timezone)
                bound = bound.astimezone(Timestamp.utc)
                if is_lower_bound:
                    statement = statement.where(time_column >= bound)
                else:
                    statement = statement.where(time_column <= bound)
            statement = statement.execution_options(stream_results=True)
            cursor = self._session.execute(statement)
        return VersionedItem(
            self._library.name,
            symbol_name,
            version.version,
            version.timestamp,
            cursor,
            version.user_metadata,
        )

    def read(self, *args, **kwargs):
        """Alias of ``read_as_dataframe``; all arguments are passed through."""
        return self.read_as_dataframe(*args, **kwargs)

    def _prune_previous_versions(self, symbol, keep_mins=120):
        """Delete the versions of ``symbol`` reported as prunable.

        ``keep_mins`` is handed to ``symbol.get_prunable_versions``.  The
        deletions are committed together; on any error the session is rolled
        back and the error re-raised.
        """
        symbol = self._library.get_symbol(symbol)
        prunable_versions = symbol.get_prunable_versions(keep_mins)
        try:
            for version in prunable_versions:
                version.delete()
            self._session.commit()
        except BaseException:
            self._session.rollback()
            raise

    def _delete_version(self, symbol, version):
        """Delete one version of ``symbol`` identified by its number.

        The lookup passes ``deleted=True`` to ``get_version_by_number``.
        """
        symbol = self._library.get_symbol(symbol)
        version = symbol.get_version_by_number(version, deleted=True)
        try:
            version.delete()
            self._session.commit()
        except BaseException:
            self._session.rollback()
            raise

    def delete(self, symbol):
        """Remove ``symbol`` from the library.

        A table-less version flagged ``{"deleted": True}`` is written first,
        then versions are pruned with ``keep_mins=0``.
        """
        symbol_name = symbol
        # Lookup kept for its side effect: it fails early for an unknown
        # symbol (``has_symbol`` suggests this is ``NoResultFound``) instead
        # of letting ``write`` create the symbol.
        self._library.get_symbol(symbol_name)
        self.write(
            symbol_name, None, prune_previous_version=False, metadata={"deleted": True}
        )
        self._prune_previous_versions(symbol_name, 0)
        # Post-condition sanity check (skipped when Python runs with -O).
        assert not self.has_symbol(symbol_name)

    def list_snapshots(self):
        """Return the names of all snapshots of the library."""
        return [snapshot.name for snapshot in self._library.snapshots]

    def snapshot(self, snapshot):
        """Create a snapshot named ``snapshot`` and commit it.

        On any error the session is rolled back and the error re-raised.
        """
        try:
            self._library.create_snapshot(snapshot)
            self._session.commit()
        except BaseException:
            self._session.rollback()
            raise

    def delete_snapshot(self, snapshot):
        """Delete the snapshot named ``snapshot`` and commit the removal.

        On any error the session is rolled back and the error re-raised.
        """
        snapshot = self._library.get_snapshot(snapshot)
        try:
            snapshot.delete()
            self._session.commit()
        except BaseException:
            self._session.rollback()
            raise