Source code for sqlalchemy_mptt.events

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2014 uralbash <root@uralbash.ru>
#
# Distributed under terms of the MIT license.

"""
SQLAlchemy events extension
"""
import weakref

from sqlalchemy import and_, case, event, inspection, select
from sqlalchemy.orm import object_session
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.sql import func


def _insert_subtree(table, connection, node_size, node_pos_left,
                    node_pos_right, parent_pos_left, parent_pos_right,
                    subtree, parent_tree_id, parent_level, node_level,
                    left_sibling, table_pk):
    """Re-insert a detached subtree into the tree ``parent_tree_id``.

    The rows whose primary keys are listed in *subtree* are renumbered so that
    the subtree starts right after *left_sibling*:

    * directly after ``left_sibling['lft']`` when ``left_sibling['is_parent']``
      is true, i.e. as the first child of that node;
    * directly after ``left_sibling['rgt']`` otherwise, i.e. as its next
      sibling.

    Their ``level`` is rebased so that the subtree root sits at
    ``parent_level + 1``, and their ``tree_id`` becomes ``parent_tree_id``.
    Then every other row of that tree whose ``rgt`` is at or beyond the
    insertion point gets ``rgt`` widened by *node_size*. Its ``lft`` is
    shifted too when it lies after ``left_sibling['lft']``.

    *parent_pos_left* and *parent_pos_right* are accepted for call
    compatibility but are not used here.
    """
    # First free lft slot for the subtree, then its last rgt slot.
    as_first_child = left_sibling['lft'] + 1
    insert_at = (as_first_child if left_sibling['is_parent']
                 else left_sibling['rgt'] + 1)
    insert_end = insert_at + node_size - 1

    # Step 1: move the subtree rows into their new slots.
    relocate_subtree = (
        table.update(table_pk.in_(subtree))
        .values(
            lft=table.c.lft - node_pos_left + insert_at,
            rgt=table.c.rgt - node_pos_right + insert_end,
            level=table.c.level - node_level + parent_level + 1,
            tree_id=parent_tree_id,
        )
    )
    connection.execute(relocate_subtree)

    # Step 2: make room for it in the rest of the target tree. Ancestors
    # only grow their rgt; nodes after the sibling's lft shift entirely.
    make_room = (
        table.update(
            and_(
                table.c.rgt > insert_at - 1,
                table_pk.notin_(subtree),
                table.c.tree_id == parent_tree_id,
            )
        )
        .values(
            rgt=table.c.rgt + node_size,
            lft=case(
                [(table.c.lft > left_sibling['lft'],
                  table.c.lft + node_size)],
                else_=table.c.lft,
            ),
        )
    )
    connection.execute(make_room)
def _get_tree_table(mapper):
    """Return the first table of *mapper* that carries the nested-set columns.

    A table qualifies when its column collection ``c`` contains every one of
    ``level``, ``lft``, ``rgt`` and ``parent_id``. Tables are checked in the
    order ``mapper.tables`` yields them. ``None`` is returned when none
    qualifies.
    """
    required = ('level', 'lft', 'rgt', 'parent_id')
    return next(
        (candidate for candidate in mapper.tables
         if all(name in candidate.c for name in required)),
        None,
    )
def mptt_before_insert(mapper, connection, instance):
    """Assign nested-set coordinates to a node that is about to be inserted.

    Based on example
    https://bitbucket.org/zzzeek/sqlalchemy/src/73095b353124/examples/nested_sets/nested_sets.py?at=master

    A node with a falsy ``parent_id`` becomes the only node of a new tree:
    ``left=1``, ``right=2``, ``level=1``, and ``tree_id`` is the current
    maximum ``tree_id`` + 1, or 1 when that yields nothing.

    Any other node is appended as the last child of its parent. Every row of
    the parent's tree whose ``rgt`` is at or past the parent's ``rgt`` is
    shifted by 2 (its ``lft`` only when strictly past it). The new node then
    takes the parent's old ``rgt`` as its ``left``.
    """
    table = _get_tree_table(mapper)
    pk_column = instance.get_pk_column()
    table_pk = getattr(table.c, pk_column.name)

    if not instance.parent_id:
        # Root of a brand-new tree.
        instance.left, instance.right, instance.level = 1, 2, 1
        next_tree_id = connection.scalar(
            select([func.max(table.c.tree_id) + 1]))
        instance.tree_id = next_tree_id or 1
        return

    parent_row = connection.execute(
        select([table.c.lft, table.c.rgt, table.c.tree_id, table.c.level])
        .where(table_pk == instance.parent_id)
    ).fetchone()
    _parent_left, parent_right, parent_tree_id, parent_level = parent_row

    # Open a two-slot gap at the parent's right edge.
    open_gap = table.update(
        and_(table.c.rgt >= parent_right,
             table.c.tree_id == parent_tree_id)
    ).values(
        lft=case(
            [(table.c.lft > parent_right, table.c.lft + 2)],
            else_=table.c.lft
        ),
        rgt=case(
            [(table.c.rgt >= parent_right, table.c.rgt + 2)],
            else_=table.c.rgt
        )
    )
    connection.execute(open_gap)

    instance.level = parent_level + 1
    instance.tree_id = parent_tree_id
    instance.left = parent_right
    instance.right = parent_right + 1
def mptt_before_delete(mapper, connection, instance, delete=True):
    """Close the gap that *instance*'s subtree leaves in its tree.

    Used as the ``before_delete`` handler (``delete=True``). It is also called
    by ``mptt_before_update`` with ``delete=False`` to detach a node that is
    about to be moved elsewhere.

    :param mapper: mapper of the flushed class; the tree table is looked up
        among its tables.
    :param connection: the connection the flush runs on.
    :param instance: node being deleted or moved.
    :param delete: when true, the node's own row is deleted here with a Core
        ``DELETE``. When false, the surrounding ``lft``/``rgt`` values are
        only renumbered.
    """
    table = _get_tree_table(mapper)
    pk = getattr(instance, instance.get_pk_name())
    db_pk = instance.get_pk_column()
    table_pk = getattr(table.c, db_pk.name)

    # Read the position *and* the tree from the database row. lft/rgt were
    # always taken from the database. tree_id used to come from the in-memory
    # instance, which can be stale mid-flush (mptt_before_update renumbers
    # tree_ids with a Core UPDATE), so the UPDATE below could renumber the
    # wrong tree.
    lft, rgt, tree_id = connection.execute(
        select([table.c.lft, table.c.rgt, table.c.tree_id])
        .where(table_pk == pk)
    ).fetchone()
    # Number of lft/rgt slots occupied by the node and its descendants.
    delta = rgt - lft + 1

    if delete:
        # The row is removed right here, so the ORM's own DELETE for this
        # object will match no row. Turn off the mapper's deleted-row count
        # check that would otherwise complain about it.
        # NOTE: the flag is set on the base mapper and is never restored.
        mapper.base_mapper.confirm_deleted_rows = False
        # NOTE(review): only this row is deleted here; descendants are
        # presumably removed by a cascade configured elsewhere, while the
        # renumbering below assumes the whole subtree is gone.
        connection.execute(
            table.delete(table_pk == pk)
        )

    # A root has the largest rgt of its tree, so for a deleted root the
    # UPDATE below would match nothing; it is skipped.
    if instance.parent_id or not delete:
        # Equivalent SQL:
        #   UPDATE tree
        #      SET left_id = CASE WHEN left_id > $leftId
        #                         THEN left_id - $delta ELSE left_id END,
        #          right_id = CASE WHEN right_id >= $rightId
        #                          THEN right_id - $delta ELSE right_id END
        # Ancestors only shrink their rgt; nodes to the right shift fully.
        connection.execute(
            table.update(
                and_(table.c.rgt > rgt, table.c.tree_id == tree_id))
            .values(
                lft=case(
                    [(table.c.lft > lft, table.c.lft - delta)],
                    else_=table.c.lft
                ),
                rgt=case(
                    [(table.c.rgt >= rgt, table.c.rgt - delta)],
                    else_=table.c.rgt
                )
            )
        )
def mptt_before_update(mapper, connection, instance):
    """ Re-position *instance* in the nested set before its UPDATE is flushed.

    Based on this example:
    http://stackoverflow.com/questions/889527/move-node-in-nested-set

    Optional move hints are read from the instance when present:

    * ``mptt_move_inside``: only its truthiness is used here. A truthy value
      forces the move even when ``parent_id`` is unchanged.
    * ``mptt_move_before``: primary key of the node that should end up right
      after *instance*.
    * ``mptt_move_after``: primary key of the node that should end up right
      before *instance*. ``instance.parent_id`` is set to that node's parent.

    When nothing moves, the function returns after a few SELECTs. Otherwise
    the node's subtree is detached with
    ``mptt_before_delete(..., delete=False)``. It is then re-inserted under
    ``instance.parent_id`` through ``_insert_subtree``, or, when there is no
    parent, turned into a top-level tree with a new ``tree_id``.
    """
    node_id = getattr(instance, instance.get_pk_name())
    table = _get_tree_table(mapper)
    db_pk = instance.get_pk_column()
    table_pk = getattr(table.c, db_pk.name)
    mptt_move_inside = None
    # dict {'lft', 'rgt', 'is_parent'} describing where to insert the subtree
    left_sibling = None
    # set when the node must become a top-level tree placed after this tree_id
    left_sibling_tree_id = None

    if hasattr(instance, 'mptt_move_inside'):
        mptt_move_inside = instance.mptt_move_inside

    if hasattr(instance, 'mptt_move_before'):
        (right_sibling_left, right_sibling_right, right_sibling_parent,
         right_sibling_level, right_sibling_tree_id) = connection.execute(
            select([table.c.lft, table.c.rgt, table.c.parent_id,
                    table.c.level, table.c.tree_id]).
            where(table_pk == instance.mptt_move_before)
        ).fetchone()
        # NOTE(review): this query has no ORDER BY, so current_lvl_nodes[-1]
        # below is not guaranteed to be the nearest preceding node. It also
        # does not filter on right_sibling_parent, so a same-level node under
        # an earlier parent (a cousin) can be picked. Confirm intended.
        current_lvl_nodes = connection.execute(
            select([table.c.lft, table.c.rgt, table.c.parent_id,
                    table.c.tree_id]).
            where(and_(table.c.level == right_sibling_level,
                       table.c.tree_id == right_sibling_tree_id,
                       table.c.lft < right_sibling_left))
        ).fetchall()
        if current_lvl_nodes:
            (left_sibling_left, left_sibling_right, left_sibling_parent,
             left_sibling_tree_id) = current_lvl_nodes[-1]
            instance.parent_id = left_sibling_parent
            left_sibling = {'lft': left_sibling_left,
                            'rgt': left_sibling_right,
                            'is_parent': False}
        # if move_before to top level: become the tree just before the
        # right sibling's tree
        elif not right_sibling_parent:
            left_sibling_tree_id = right_sibling_tree_id - 1
        # NOTE(review): when the right sibling is a first child (no earlier
        # node found, but it has a parent) no placement is set here.

    # if placed after a particular node
    if hasattr(instance, 'mptt_move_after'):
        (left_sibling_left, left_sibling_right, left_sibling_parent,
         left_sibling_tree_id) = connection.execute(
            select([table.c.lft, table.c.rgt, table.c.parent_id,
                    table.c.tree_id]).
            where(table_pk == instance.mptt_move_after)
        ).fetchone()
        instance.parent_id = left_sibling_parent
        left_sibling = {'lft': left_sibling_left,
                        'rgt': left_sibling_right,
                        'is_parent': False}

    # Primary keys of the moving node and all its descendants, in lft order.
    # NOTE(review): bounds come from the in-memory instance.left/right/
    # tree_id, not from the database row read just below; confirm they
    # cannot be stale within a flush.
    """ Get subtree from node SELECT id, name, level FROM my_tree WHERE left_key >= $left_key AND right_key <= $right_key ORDER BY left_key """
    subtree = connection.execute(
        select([table_pk])
        .where(and_(table.c.lft >= instance.left,
                    table.c.rgt <= instance.right,
                    table.c.tree_id == instance.tree_id))
        .order_by(table.c.lft)
    ).fetchall()
    subtree = [x[0] for x in subtree]

    # Current (pre-move) position of the node as stored in the database.
    """ step 0: Initialize parameters. Put there left and right position of moving node """
    (node_pos_left, node_pos_right, node_tree_id, node_parent_id,
     node_level) = connection.execute(
        select([table.c.lft, table.c.rgt, table.c.tree_id,
                table.c.parent_id, table.c.level])
        .where(table_pk == node_id)
    ).fetchone()

    # if instance just update w/o move: no sibling target, parent unchanged
    # and no move_inside request. Parent ids are compared as strings,
    # presumably so that an int id and its str form count as equal;
    # TODO confirm.
    if not left_sibling\
            and str(node_parent_id) == str(instance.parent_id)\
            and not mptt_move_inside:
        if left_sibling_tree_id is None:
            return

    # fix tree shorting: a root node (no stored parent) whose requested new
    # parent lives in its own tree would be moved under its own descendant;
    # cancel the move by clearing parent_id instead.
    if instance.parent_id:
        (parent_id, parent_pos_right, parent_pos_left, parent_tree_id,
         parent_level) = connection.execute(
            select([table_pk, table.c.rgt, table.c.lft, table.c.tree_id,
                    table.c.level])
            .where(table_pk == instance.parent_id)
        ).fetchone()
        if not node_parent_id and node_tree_id == parent_tree_id:
            instance.parent_id = None
            return

    # delete from old tree: close the gap without deleting any row
    mptt_before_delete(mapper, connection, instance, False)

    if instance.parent_id:
        # Re-read the new parent: closing the gap above may have changed
        # its lft/rgt.
        """ Put there right position of new parent node (there moving node should be moved) """
        (parent_id, parent_pos_right, parent_pos_left, parent_tree_id,
         parent_level) = connection.execute(
            select([table_pk, table.c.rgt, table.c.lft, table.c.tree_id,
                    table.c.level])
            .where(table_pk == instance.parent_id)
        ).fetchone()
        # 'size' of moving node (including all its sub nodes)
        node_size = node_pos_right - node_pos_left + 1

        # no explicit sibling: insert as the first child of the parent
        if not left_sibling:
            left_sibling = {'lft': parent_pos_left, 'rgt': parent_pos_right,
                            'is_parent': True}

        # insert subtree in existing tree
        instance.tree_id = parent_tree_id
        _insert_subtree(table, connection, node_size, node_pos_left,
                        node_pos_right, parent_pos_left, parent_pos_right,
                        subtree, parent_tree_id, parent_level, node_level,
                        left_sibling, table_pk)
    else:
        # if insert after a given tree: take the next tree_id and push every
        # later tree one slot up
        if left_sibling_tree_id or left_sibling_tree_id == 0:
            tree_id = left_sibling_tree_id + 1
            connection.execute(
                table.update(table.c.tree_id > left_sibling_tree_id)
                .values(tree_id=table.c.tree_id + 1)
            )
        # if just insert: append as a new last tree
        else:
            tree_id = connection.scalar(
                select([func.max(table.c.tree_id) + 1]))

        # Rebase the subtree so its root gets lft=1 and level=1.
        connection.execute(
            table.update(table_pk.in_(subtree))
            .values(
                lft=table.c.lft - node_pos_left + 1,
                rgt=table.c.rgt - node_pos_left + 1,
                level=table.c.level - node_level + 1,
                tree_id=tree_id
            )
        )
class _WeakDictBasedSet(weakref.WeakKeyDictionary, object):
    """Set-like container holding its members through weak references.

    Members are the keys of the underlying ``WeakKeyDictionary`` (all mapped
    to ``None``), so a member disappears once it is garbage collected.
    """

    def add(self, obj):
        """Insert *obj*; adding an existing member changes nothing."""
        self[obj] = None

    def discard(self, obj):
        """Remove *obj* if it is a member; silently ignore it otherwise."""
        # The mapping-level pop is needed because ``pop`` is redefined below
        # with set semantics.
        super(_WeakDictBasedSet, self).pop(obj, None)

    def pop(self):
        """Remove and return an arbitrary member; ``KeyError`` when empty."""
        member, _unused = self.popitem()
        return member


class _WeakDefaultDict(weakref.WeakKeyDictionary, object):
    """Weak-keyed mapping that creates an empty ``_WeakDictBasedSet`` on a
    lookup of a missing key, stores it under that key and returns it."""

    def __getitem__(self, key):
        try:
            return super(_WeakDefaultDict, self).__getitem__(key)
        except KeyError:
            pass
        fresh = _WeakDictBasedSet()
        self[key] = fresh
        return fresh
class TreesManager(object):
    """
    Manages events dispatching for all subclasses of a given class.

    The mapper-level ``before_insert``, ``before_update`` and
    ``before_delete`` handlers are attached to ``base_class`` with
    ``propagate=True``, so they also fire for its subclasses. Instances
    inserted or updated during a flush are remembered per session so that
    :meth:`after_flush_postexec` can expire their tree attributes afterwards.
    """

    # Tree attributes that the handlers may change in the database through
    # Core statements, bypassing the ORM; they are expired after a flush.
    _EXPIRE_ATTRS = ('left', 'right', 'tree_id', 'level')

    def __init__(self, base_class):
        self.base_class = base_class
        self.classes = set()
        # session -> weak set of tree instances inserted/updated in its flush
        self.instances = _WeakDefaultDict()

    def register_events(self, remove=False):
        """Attach the mapper event handlers to ``base_class``, or detach them.

        :param remove: when false (the default), attach every handler that is
            not attached yet. When true, detach every handler that is
            attached, and attach nothing.
        :returns: ``self``, so the call can be chained.
        """
        for event_name, handler in (
            ('before_insert', self.before_insert),
            ('before_update', self.before_update),
            ('before_delete', self.before_delete),
        ):
            is_event_exist = event.contains(
                self.base_class, event_name, handler)
            if remove:
                # A removal request must never register a handler that was
                # not attached in the first place.
                if is_event_exist:
                    event.remove(self.base_class, event_name, handler)
            elif not is_event_exist:
                event.listen(self.base_class, event_name, handler,
                             propagate=True)
        return self

    def register_factory(self, sessionmaker):
        """
        Registers this TreesManager instance to respond on
        `after_flush_postexec` events on the given session or session
        factory. This method returns the original argument, so that it can
        be used by wrapping an already existing instance:

        .. code-block:: python
            :linenos:

            from sqlalchemy import create_engine
            from sqlalchemy.orm import sessionmaker

            from sqlalchemy_mptt.events import TreesManager
            from sqlalchemy_mptt.mixins import BaseNestedSets

            engine = create_engine('...')

            trees_manager = TreesManager(BaseNestedSets)
            trees_manager.register_events()

            Session = trees_manager.register_factory(
                sessionmaker(bind=engine)
            )

        A reference to this method, bound to a default instance of this
        class and already registered to a mapper, is importable directly
        from `sqlalchemy_mptt`:

        .. code-block:: python
            :linenos:

            from sqlalchemy import create_engine
            from sqlalchemy.orm import sessionmaker

            from sqlalchemy_mptt import mptt_sessionmaker

            engine = create_engine('...')
            Session = mptt_sessionmaker(sessionmaker(bind=engine))
        """
        event.listen(sessionmaker, 'after_flush_postexec',
                     self.after_flush_postexec)
        return sessionmaker

    def before_insert(self, mapper, connection, instance):
        """Remember *instance* for post-flush expiry and number it."""
        session = object_session(instance)
        self.instances[session].add(instance)
        mptt_before_insert(mapper, connection, instance)

    def before_update(self, mapper, connection, instance):
        """Remember *instance* for post-flush expiry and apply any move."""
        session = object_session(instance)
        self.instances[session].add(instance)
        mptt_before_update(mapper, connection, instance)

    def before_delete(self, mapper, connection, instance):
        """Forget *instance* (it is going away) and close its gap."""
        session = object_session(instance)
        self.instances[session].discard(instance)
        mptt_before_delete(mapper, connection, instance)

    def after_flush_postexec(self, session, context):
        """
        Event listener that expires the ``left``, ``right``, ``tree_id`` and
        ``level`` attributes of every instance inserted or updated in this
        flush. It also expires those attributes on the ancestors reachable
        through already-loaded ``parent`` relationships, and on all
        descendants of each processed instance.
        """
        instances = self.instances[session]
        while instances:
            instance = instances.pop()
            if instance not in session:
                continue
            parent = self.get_parent_value(instance)
            while parent is not NO_VALUE and parent is not None:
                # The ancestor is expired here, so drop it as a work item.
                instances.discard(parent)
                session.expire(parent, self._EXPIRE_ATTRS)
                parent = self.get_parent_value(parent)
            # (Formerly a while/else clause; with no ``break`` it always ran.)
            session.expire(instance, self._EXPIRE_ATTRS)
            self.expire_session_for_children(session, instance)

    @staticmethod
    def get_parent_value(instance):
        """Return the already-loaded ``parent`` of *instance*.

        Uses the attribute's ``loaded_value``, so ``NO_VALUE`` comes back
        when the relationship has not been loaded.
        """
        return inspection.inspect(instance).attrs.parent.loaded_value

    @staticmethod
    def expire_session_for_children(session, instance):
        """Expire the tree attributes of every descendant of *instance*.

        Walks the ``children`` relationship depth-first. Reading ``children``
        may emit lazy-load queries. An explicit stack replaces recursion so
        that very deep trees cannot hit the interpreter's recursion limit.
        """
        pending = [instance]
        while pending:
            children = pending.pop().children
            if children is NO_VALUE or children is None:
                continue
            for child in children:
                session.expire(child, TreesManager._EXPIRE_ATTRS)
                pending.append(child)
  Read the Docs
v: latest  
Versions
latest
Free document hosting provided by Read the Docs.