# Source code for sqlalchemy_mptt.events

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2014 uralbash <root@uralbash.ru>
#
# Distributed under terms of the MIT license.

"""
SQLAlchemy events extension
"""
# standard library
import weakref

# SQLAlchemy
from sqlalchemy import and_, case, event, select, inspection
from sqlalchemy.orm import object_session
from sqlalchemy.sql import func
from sqlalchemy.orm.base import NO_VALUE


def _insert_subtree(
        table,
        connection,
        node_size,
        node_pos_left,
        node_pos_right,
        parent_pos_left,
        parent_pos_right,
        subtree,
        parent_tree_id,
        parent_level,
        node_level,
        left_sibling,
        table_pk
):
    """Move the rows listed in ``subtree`` into tree ``parent_tree_id``.

    The caller (``mptt_before_update``) has already closed the gap the
    subtree left in its old tree; this function renumbers the subtree and
    opens a matching gap at its destination.

    :param table: the nested-set table.
    :param connection: connection the UPDATE statements are executed on.
    :param node_size: number of lft/rgt slots the subtree occupies
        (expected to be ``node_pos_right - node_pos_left + 1``).
    :param node_pos_left: current ``lft`` of the subtree root.
    :param node_pos_right: current ``rgt`` of the subtree root.
    :param parent_pos_left: accepted for interface compatibility; unused.
    :param parent_pos_right: accepted for interface compatibility; unused.
    :param subtree: primary-key values of every row in the moved subtree.
    :param parent_tree_id: ``tree_id`` of the destination tree.
    :param parent_level: level of the new parent; the subtree root is
        placed one level below it.
    :param node_level: current level of the subtree root.
    :param left_sibling: dict with ``lft``, ``rgt`` and ``is_parent``.
        When ``is_parent`` is true it describes the new parent itself and
        the subtree becomes its first child. Otherwise the subtree is
        placed right after that node.
    :param table_pk: primary-key column of ``table``.
    """
    anchor_left = left_sibling['lft']
    if left_sibling['is_parent']:
        new_left = anchor_left + 1
    else:
        new_left = left_sibling['rgt'] + 1
    new_right = new_left + node_size - 1

    # Step 1: renumber the subtree into [new_left, new_right] and re-level
    # it below the new parent.
    connection.execute(
        table.update(table_pk.in_(subtree)).values(
            lft=table.c.lft - node_pos_left + new_left,
            rgt=table.c.rgt - node_pos_right + new_right,
            level=table.c.level - node_level + parent_level + 1,
            tree_id=parent_tree_id
        )
    )

    # Step 2: open a node_size-wide gap by shifting every other node of
    # the destination tree whose rgt reaches new_left. lft is shifted only
    # for nodes that start after the anchor.
    shifted_left = case(
        [(table.c.lft > anchor_left, table.c.lft + node_size)],
        else_=table.c.lft
    )
    connection.execute(
        table.update(
            and_(
                table.c.rgt > new_left - 1,
                table_pk.notin_(subtree),
                table.c.tree_id == parent_tree_id
            )
        ).values(
            rgt=table.c.rgt + node_size,
            lft=shifted_left
        )
    )
def _get_tree_table(mapper):
    """Return the first table of ``mapper.tables`` carrying the MPTT columns.

    A table qualifies when its column collection ``c`` contains ``level``,
    ``lft``, ``rgt`` and ``parent_id``. Returns ``None`` when no table
    qualifies.
    """
    required_columns = ('level', 'lft', 'rgt', 'parent_id')
    return next(
        (
            candidate for candidate in mapper.tables
            if all(name in candidate.c for name in required_columns)
        ),
        None
    )
def mptt_before_insert(mapper, connection, instance):
    """Give a node that is about to be INSERTed its nested-set coordinates.

    Based on example
    https://bitbucket.org/zzzeek/sqlalchemy/src/73095b353124/examples/nested_sets/nested_sets.py?at=master

    * Without ``parent_id``, the node starts a new tree. It gets
      ``left=1``, ``right=2`` and the default level. Its ``tree_id`` is one
      greater than the current maximum, or ``1`` when the query yields
      nothing.
    * With ``parent_id``, the node is appended as the parent's last child.
      Nodes of the parent's tree with ``rgt`` at or beyond the parent's
      ``rgt`` are shifted by 2, and the node takes the two freed slots.

    Only attributes on ``instance`` are set for the new row. Shifts of
    existing rows are executed directly on ``connection``.
    """
    table = _get_tree_table(mapper)
    pk_column = getattr(table.c, instance.get_pk_column().name)

    if instance.parent_id is None:
        instance.left = 1
        instance.right = 2
        instance.level = instance.get_default_level()
        next_tree_id = connection.scalar(
            select([func.max(table.c.tree_id) + 1])
        )
        instance.tree_id = next_tree_id or 1
        return

    parent_row = connection.execute(
        select([table.c.lft, table.c.rgt, table.c.tree_id, table.c.level])
        .where(pk_column == instance.parent_id)
    ).fetchone()
    _parent_left, parent_right, parent_tree, parent_level = parent_row

    # Free two slots starting at the parent's right boundary.
    connection.execute(
        table.update(
            and_(
                table.c.rgt >= parent_right,
                table.c.tree_id == parent_tree
            )
        ).values(
            lft=case(
                [(table.c.lft > parent_right, table.c.lft + 2)],
                else_=table.c.lft
            ),
            rgt=case(
                [(table.c.rgt >= parent_right, table.c.rgt + 2)],
                else_=table.c.rgt
            )
        )
    )

    instance.level = parent_level + 1
    instance.tree_id = parent_tree
    instance.left = parent_right
    instance.right = parent_right + 1
def mptt_before_delete(mapper, connection, instance, delete=True):
    """Close the gap a node's subtree occupies in its tree's numbering.

    :param mapper: mapper of ``instance``; used to locate the tree table.
    :param connection: connection the statements are executed on.
    :param instance: the node being deleted (or moved).
    :param delete: when true, the node's own row is DELETEd here and
        ``mapper.base_mapper.confirm_deleted_rows`` is switched off
        (presumably so the ORM's row-count check for its own DELETE does
        not complain). ``mptt_before_update`` passes ``False`` to only
        renumber, leaving the subtree rows in place.

    Nodes of the same tree whose ``rgt`` exceeds the node's ``rgt`` are
    shifted left by the subtree width (``rgt - lft + 1``). Their ``lft`` is
    shifted only when it is greater than the node's ``lft``. Renumbering
    is skipped when deleting a node whose ``parent_id`` is ``None``.
    """
    table = _get_tree_table(mapper)
    tree_id = instance.tree_id
    pk_value = getattr(instance, instance.get_pk_name())
    pk_column = getattr(table.c, instance.get_pk_column().name)

    node_left, node_right = connection.execute(
        select([table.c.lft, table.c.rgt]).where(pk_column == pk_value)
    ).fetchone()
    width = node_right - node_left + 1

    if delete:
        mapper.base_mapper.confirm_deleted_rows = False
        connection.execute(table.delete(pk_column == pk_value))

    if instance.parent_id is None and delete:
        return

    # UPDATE tree
    #   SET left_id  = CASE WHEN left_id > $leftId THEN left_id - $delta
    #                       ELSE left_id END,
    #       right_id = CASE WHEN right_id >= $rightId THEN right_id - $delta
    #                       ELSE right_id END
    connection.execute(
        table.update(
            and_(
                table.c.rgt > node_right,
                table.c.tree_id == tree_id
            )
        ).values(
            lft=case(
                [(table.c.lft > node_left, table.c.lft - width)],
                else_=table.c.lft
            ),
            rgt=case(
                [(table.c.rgt >= node_right, table.c.rgt - width)],
                else_=table.c.rgt
            )
        )
    )
def mptt_before_update(mapper, connection, instance):
    """ Based on this example:
    http://stackoverflow.com/questions/889527/move-node-in-nested-set

    Move ``instance`` (with its whole subtree) inside the nested-set
    structure when it was re-parented or asked to move next to a sibling.

    The move is driven by attributes of ``instance``:

    * ``parent_id``: the (possibly new) parent; ``None`` means top level.
    * ``mptt_move_inside`` (optional): when truthy, the "nothing moved"
      early exit is skipped even if ``parent_id`` is unchanged.
    * ``mptt_move_before`` (optional): pk of the node that ``instance``
      should be placed before (its future right sibling).
    * ``mptt_move_after`` (optional): pk of the node that ``instance``
      should be placed after (its future left sibling).

    First the subtree's gap in its old tree is closed with
    ``mptt_before_delete(..., False)``. Then the subtree is either spliced
    into the new parent's tree via ``_insert_subtree`` or renumbered as a
    top-level tree.
    """
    node_id = getattr(instance, instance.get_pk_name())
    table = _get_tree_table(mapper)
    db_pk = instance.get_pk_column()
    default_level = instance.get_default_level()
    table_pk = getattr(table.c, db_pk.name)
    mptt_move_inside = None
    left_sibling = None
    left_sibling_tree_id = None

    if hasattr(instance, 'mptt_move_inside'):
        mptt_move_inside = instance.mptt_move_inside

    if hasattr(instance, 'mptt_move_before'):
        (
            right_sibling_left,
            right_sibling_right,
            right_sibling_parent,
            right_sibling_level,
            right_sibling_tree_id
        ) = connection.execute(
            select(
                [
                    table.c.lft,
                    table.c.rgt,
                    table.c.parent_id,
                    table.c.level,
                    table.c.tree_id
                ]
            ).where(
                table_pk == instance.mptt_move_before
            )
        ).fetchone()
        # Same-level nodes of that tree lying left of the right sibling.
        # ORDER BY lft is required: the code below takes the last row as
        # the closest (left) sibling, and without an ORDER BY the database
        # may return the rows in any order.
        # NOTE(review): the query is not restricted to the right sibling's
        # parent, so when the right sibling is a first child the closest
        # same-level node belongs to another parent; confirm intended.
        current_lvl_nodes = connection.execute(
            select(
                [
                    table.c.lft,
                    table.c.rgt,
                    table.c.parent_id,
                    table.c.tree_id
                ]
            ).where(
                and_(
                    table.c.level == right_sibling_level,
                    table.c.tree_id == right_sibling_tree_id,
                    table.c.lft < right_sibling_left
                )
            ).order_by(
                table.c.lft
            )
        ).fetchall()
        if current_lvl_nodes:
            (
                left_sibling_left,
                left_sibling_right,
                left_sibling_parent,
                left_sibling_tree_id
            ) = current_lvl_nodes[-1]
            instance.parent_id = left_sibling_parent
            left_sibling = {
                'lft': left_sibling_left,
                'rgt': left_sibling_right,
                'is_parent': False
            }
        # if move_before to top level
        elif not right_sibling_parent:
            left_sibling_tree_id = right_sibling_tree_id - 1

    # if placed after a particular node
    if hasattr(instance, 'mptt_move_after'):
        (
            left_sibling_left,
            left_sibling_right,
            left_sibling_parent,
            left_sibling_tree_id
        ) = connection.execute(
            select(
                [
                    table.c.lft,
                    table.c.rgt,
                    table.c.parent_id,
                    table.c.tree_id
                ]
            ).where(
                table_pk == instance.mptt_move_after
            )
        ).fetchone()
        instance.parent_id = left_sibling_parent
        left_sibling = {
            'lft': left_sibling_left,
            'rgt': left_sibling_right,
            'is_parent': False
        }

    # Get subtree from node:
    #   SELECT id, name, level FROM my_tree
    #   WHERE left_key >= $left_key AND right_key <= $right_key
    #   ORDER BY left_key
    subtree = connection.execute(
        select([table_pk])
        .where(
            and_(
                table.c.lft >= instance.left,
                table.c.rgt <= instance.right,
                table.c.tree_id == instance.tree_id
            )
        ).order_by(
            table.c.lft
        )
    ).fetchall()
    subtree = [x[0] for x in subtree]

    # step 0: Initialize parameters.
    # Put there left and right position of moving node.
    (
        node_pos_left,
        node_pos_right,
        node_tree_id,
        node_parent_id,
        node_level
    ) = connection.execute(
        select(
            [
                table.c.lft,
                table.c.rgt,
                table.c.tree_id,
                table.c.parent_id,
                table.c.level
            ]
        ).where(
            table_pk == node_id
        )
    ).fetchone()
    # 'size' of moving node (including all its sub nodes)
    node_size = node_pos_right - node_pos_left + 1

    # if instance just update w/o move
    # NOTE(review): str() presumably tolerates a parent_id of a different
    # type on the instance (e.g. "5" vs 5); confirm.
    if not left_sibling \
            and str(node_parent_id) == str(instance.parent_id) \
            and not mptt_move_inside:
        if left_sibling_tree_id is None:
            return

    # fix tree shorting
    if instance.parent_id is not None:
        (
            parent_id,
            parent_pos_right,
            parent_pos_left,
            parent_tree_id,
            parent_level
        ) = connection.execute(
            select(
                [
                    table_pk,
                    table.c.rgt,
                    table.c.lft,
                    table.c.tree_id,
                    table.c.level
                ]
            ).where(
                table_pk == instance.parent_id
            )
        ).fetchone()
        if node_parent_id is None and node_tree_id == parent_tree_id:
            instance.parent_id = None
            return

    # delete from old tree (close the gap; the subtree rows are kept)
    old_tree_id = instance.tree_id
    mptt_before_delete(mapper, connection, instance, False)

    # The gap-closing UPDATE just shifted every node of the old tree whose
    # rgt exceeds node_pos_right left by node_size. The left-sibling
    # positions were read before that UPDATE, so apply the same shift.
    # Otherwise a left sibling lying to the right of the moved node makes
    # the subtree land node_size slots too far right.
    if left_sibling is not None \
            and left_sibling_tree_id == old_tree_id \
            and left_sibling['rgt'] > node_pos_right:
        if left_sibling['lft'] > node_pos_left:
            left_sibling['lft'] -= node_size
        left_sibling['rgt'] -= node_size

    if instance.parent_id is not None:
        # Put there right position of new parent node (there moving node
        # should be moved). Read after the gap was closed, so it is current.
        (
            parent_id,
            parent_pos_right,
            parent_pos_left,
            parent_tree_id,
            parent_level
        ) = connection.execute(
            select(
                [
                    table_pk,
                    table.c.rgt,
                    table.c.lft,
                    table.c.tree_id,
                    table.c.level
                ]
            ).where(
                table_pk == instance.parent_id
            )
        ).fetchone()

        # No sibling given: insert as the parent's first child.
        if not left_sibling:
            left_sibling = {
                'lft': parent_pos_left,
                'rgt': parent_pos_right,
                'is_parent': True
            }

        # insert subtree in exist tree
        instance.tree_id = parent_tree_id
        _insert_subtree(
            table,
            connection,
            node_size,
            node_pos_left,
            node_pos_right,
            parent_pos_left,
            parent_pos_right,
            subtree,
            parent_tree_id,
            parent_level,
            node_level,
            left_sibling,
            table_pk
        )
    else:
        # if insert after a given top-level tree (tree_id may be 0)
        if left_sibling_tree_id is not None:
            tree_id = left_sibling_tree_id + 1
            connection.execute(
                table.update(
                    table.c.tree_id > left_sibling_tree_id
                ).values(
                    tree_id=table.c.tree_id + 1
                )
            )
        # if just insert
        else:
            tree_id = connection.scalar(
                select(
                    [
                        func.max(table.c.tree_id) + 1
                    ]
                )
            )

        # Renumber the subtree as a standalone tree starting at lft == 1.
        connection.execute(
            table.update(
                table_pk.in_(
                    subtree
                )
            ).values(
                lft=table.c.lft - node_pos_left + 1,
                rgt=table.c.rgt - node_pos_left + 1,
                level=table.c.level - node_level + default_level,
                tree_id=tree_id
            )
        )
class _WeakDictBasedSet(weakref.WeakKeyDictionary, object):
    """
    A tiny weak set built on ``WeakKeyDictionary``.

    Members are stored as keys mapped to ``None``, so they are held only
    weakly.
    """

    def add(self, obj):
        """Insert *obj*; adding an existing member changes nothing."""
        self[obj] = None

    def discard(self, obj):
        """Remove *obj* when present, otherwise do nothing."""
        # ``pop`` is overridden below with set semantics, so reach past it
        # to the dictionary's ``pop(key, default)``.
        super(_WeakDictBasedSet, self).pop(obj, None)

    def pop(self):
        """Remove and return some member; ``KeyError`` when empty."""
        member, _unused = self.popitem()
        return member


class _WeakDefaultDict(weakref.WeakKeyDictionary, object):
    """
    Weak-keyed mapping whose missing keys get a fresh ``_WeakDictBasedSet``.

    A lookup of an absent key stores and returns a new empty set.
    """

    def __getitem__(self, key):
        try:
            bucket = super(_WeakDefaultDict, self).__getitem__(key)
        except KeyError:
            bucket = _WeakDictBasedSet()
            self[key] = bucket
        return bucket
class TreesManager(object):
    """
    Manages events dispatching for all subclasses of a given class.

    Mapper ``before_insert`` / ``before_update`` / ``before_delete``
    listeners are attached to ``base_class`` (propagating to subclasses).
    Instances touched during a flush are remembered per session so that
    ``after_flush_postexec`` can expire their cached tree attributes.
    """

    def __init__(self, base_class):
        self.base_class = base_class
        # NOTE(review): not read anywhere in this class.
        self.classes = set()
        # session -> weak set of tree instances handled during the flush.
        self.instances = _WeakDefaultDict()

    def register_events(self, remove=False):
        """
        Attach the mapper listeners to ``base_class``, or detach them.

        Listeners are registered with ``propagate=True``. Registration is
        idempotent: an attached listener is not added twice, and removing
        a listener that is not attached is a no-op.

        :param remove: detach the listeners instead of attaching them.
        :returns: ``self``, so calls can be chained.
        """
        for event_name, handler in (
            ('before_insert', self.before_insert),
            ('before_update', self.before_update),
            ('before_delete', self.before_delete),
        ):
            is_event_exist = event.contains(
                self.base_class, event_name, handler
            )
            if remove:
                # Never fall through to ``listen`` while removing.
                if is_event_exist:
                    event.remove(self.base_class, event_name, handler)
            elif not is_event_exist:
                event.listen(
                    self.base_class, event_name, handler, propagate=True
                )
        return self

    def register_factory(self, sessionmaker):
        """
        Registers this TreesManager instance to respond on
        `after_flush_postexec` events on the given session or session
        factory. This method returns the original argument, so that it can
        be used by wrapping an already existing instance:

        .. code-block:: python
            :linenos:

            from sqlalchemy import create_engine
            from sqlalchemy.orm import sessionmaker
            from sqlalchemy_mptt.mixins import BaseNestedSets

            engine = create_engine('...')

            trees_manager = TreesManager(BaseNestedSets)
            trees_manager.register_events()

            Session = trees_manager.register_factory(
                sessionmaker(bind=engine)
            )

        A reference to this method, bound to a default instance of this
        class and already registered to a mapper, is importable directly
        from `sqlalchemy_mptt`:

        .. code-block:: python
            :linenos:

            from sqlalchemy import create_engine
            from sqlalchemy.orm import sessionmaker
            from sqlalchemy_mptt import mptt_sessionmaker

            engine = create_engine('...')
            Session = mptt_sessionmaker(sessionmaker(bind=engine))
        """
        event.listen(sessionmaker, 'after_flush_postexec',
                     self.after_flush_postexec)
        return sessionmaker

    def before_insert(self, mapper, connection, instance):
        session = object_session(instance)
        self.instances[session].add(instance)
        mptt_before_insert(mapper, connection, instance)

    def before_update(self, mapper, connection, instance):
        session = object_session(instance)
        self.instances[session].add(instance)
        mptt_before_update(mapper, connection, instance)

    def before_delete(self, mapper, connection, instance):
        session = object_session(instance)
        self.instances[session].discard(instance)
        mptt_before_delete(mapper, connection, instance)

    def after_flush_postexec(self, session, context):
        """
        Expire ``left``, ``right``, ``tree_id`` and ``level`` on every
        instance recorded for ``session`` during this flush. The same
        attributes are expired on each instance's already-loaded ancestors
        and, recursively, on its children. The listeners above rewrite
        these columns with Core UPDATE statements, so the in-memory values
        are stale.
        """
        instances = self.instances[session]
        while instances:
            instance = instances.pop()
            if instance not in session:
                continue
            parent = self.get_parent_value(instance)
            while parent is not NO_VALUE and parent is not None:
                # This ancestor is expired here; skip it as a work item.
                instances.discard(parent)
                session.expire(parent, ['left', 'right', 'tree_id', 'level'])
                parent = self.get_parent_value(parent)
            session.expire(instance, ['left', 'right', 'tree_id', 'level'])
            self.expire_session_for_children(session, instance)

    @staticmethod
    def get_parent_value(instance):
        """Return the already-loaded ``parent`` of *instance*.

        Uses the attribute state's ``loaded_value``; ``NO_VALUE`` signals
        that the relationship is not loaded.
        """
        return inspection.inspect(instance).attrs.parent.loaded_value

    @staticmethod
    def expire_session_for_children(session, instance):
        """Expire the tree attributes of all descendants of *instance*.

        Reads ``children`` via normal attribute access, which may trigger
        a lazy load.
        """
        children = instance.children

        def expire_recursively(node):
            for item in node.children:
                session.expire(item, ['left', 'right', 'tree_id', 'level'])
                expire_recursively(item)

        if children is not NO_VALUE and children is not None:
            for item in children:
                session.expire(item, ['left', 'right', 'tree_id', 'level'])
                expire_recursively(item)