#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2014 uralbash <root@uralbash.ru>
# Copyright © 2016 Jiri Kuncar <jiri.kuncar@gmail.com>
#
# Distributed under terms of the MIT license.
"""
SQLAlchemy nested sets mixin
"""
# SQLAlchemy
from sqlalchemy import Index, Column, Integer, ForeignKey, asc, desc
from sqlalchemy.orm import backref, relationship, object_session
from sqlalchemy.ext.hybrid import hybrid_method
from sqlalchemy.orm.session import Session
from sqlalchemy.ext.declarative import declared_attr
# local
from .events import _get_tree_table
class BaseNestedSets(object):
    """ Base mixin for MPTT model.

    Example:

    .. code::

        from sqlalchemy import Boolean, Column, create_engine, Integer
        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.orm import sessionmaker

        from sqlalchemy_mptt.mixins import BaseNestedSets

        Base = declarative_base()


        class Tree(Base, BaseNestedSets):
            __tablename__ = "tree"

            id = Column(Integer, primary_key=True)
            visible = Column(Boolean)

            def __repr__(self):
                return "<Node (%s)>" % self.id

    The mixin adds:

    * the ``tree_id``, ``parent_id``, ``lft``, ``rgt`` and ``level`` columns;
    * the ``parent``/``children`` relationship;
    * indexes on ``lft``, ``rgt`` and ``level``.

    The primary-key attribute is assumed to be ``id`` unless the model sets
    ``sqlalchemy_mptt_pk_name``.
    """

    @declared_attr
    def __table_args__(cls):
        # Index names are prefixed with the table name, which keeps them
        # distinct when several models use this mixin.
        return (
            Index("%s_lft_idx" % cls.__tablename__, cls.left.name),
            Index("%s_rgt_idx" % cls.__tablename__, cls.right.name),
            Index("%s_level_idx" % cls.__tablename__, cls.level.name),
        )

    @classmethod
    def __declare_first__(cls):
        # Turn off batched INSERT/UPDATE for this mapper so that each row is
        # persisted, and its mapper events are fired, one at a time.
        # NOTE(review): presumably required by the tree-maintenance listeners
        # in ``.events``; confirm there.
        cls.__mapper__.batch = False

    @classmethod
    def get_default_level(cls):
        """
        Compatibility with Django MPTT: level value for root node.
        See https://github.com/uralbash/sqlalchemy_mptt/issues/56

        Returns ``cls.sqlalchemy_mptt_default_level`` if defined, else 1.
        """
        return getattr(cls, "sqlalchemy_mptt_default_level", 1)

    @classmethod
    def get_pk_name(cls):
        """ Name of the primary-key attribute.

        Returns ``cls.sqlalchemy_mptt_pk_name`` if defined, else ``"id"``.
        """
        return getattr(cls, "sqlalchemy_mptt_pk_name", "id")

    @classmethod
    def get_pk_column(cls):
        """ Class-level primary-key attribute (named by ``get_pk_name``). """
        return getattr(cls, cls.get_pk_name())

    def get_pk_value(self):
        """ Primary-key value of this instance. """
        return getattr(self, self.get_pk_name())

    @declared_attr
    def tree_id(cls):
        # Identifier of the tree this node belongs to.
        return Column("tree_id", Integer)

    @declared_attr
    def parent_id(cls):
        # Self-referencing FK to the primary key, with the same type.
        # Deleting a parent row cascades to its children in the database.
        pk = cls.get_pk_column()
        if not pk.name:
            # The pk Column may not have been given a name yet at this point
            # of declarative processing; name it after the attribute.
            pk.name = cls.get_pk_name()
        return Column(
            "parent_id",
            pk.type,
            ForeignKey("{}.{}".format(cls.__tablename__, pk.name), ondelete="CASCADE"),
        )

    @declared_attr
    def parent(cls):
        # Many-to-one link to the parent node, plus a ``children`` backref
        # ordered by (tree_id, lft), which is tree order.
        return relationship(
            cls,
            order_by=lambda: cls.left,
            foreign_keys=[cls.parent_id],
            remote_side="{}.{}".format(cls.__name__, cls.get_pk_name()),
            backref=backref(
                "children",
                cascade="all,delete",
                order_by=lambda: (cls.tree_id, cls.left),
            ),
        )

    @declared_attr
    def left(cls):
        # Left nested-sets bound (column ``lft``).
        return Column("lft", Integer, nullable=False)

    @declared_attr
    def right(cls):
        # Right nested-sets bound (column ``rgt``).
        return Column("rgt", Integer, nullable=False)

    @declared_attr
    def level(cls):
        # Depth of the node. The root level is ``get_default_level()``
        # (see ``rebuild_tree``).
        return Column("level", Integer, nullable=False, default=0)

    @hybrid_method
    def is_ancestor_of(self, other, inclusive=False):
        """ Class or instance level method which returns True if self is
        ancestor (closer to root) of other else False. Optional flag
        `inclusive` on whether or not to treat self as ancestor of self.

        A node is an ancestor of another when both are in the same tree and
        its ``lft``/``rgt`` interval encloses the other's. When either side is
        the class, the result is a SQL expression usable in ``filter()``.

        For example see:

        * :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure`
        """
        if inclusive:
            return (
                (self.tree_id == other.tree_id)
                & (self.left <= other.left)
                & (other.right <= self.right)
            )
        return (
            (self.tree_id == other.tree_id)
            & (self.left < other.left)
            & (other.right < self.right)
        )

    @hybrid_method
    def is_descendant_of(self, other, inclusive=False):
        """ Class or instance level method which returns True if self is
        descendant (farther from root) of other else False. Optional flag
        `inclusive` on whether or not to treat self as descendant of self.

        For example see:

        * :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure`
        """
        return other.is_ancestor_of(self, inclusive)

    def move_inside(self, parent_id):
        """ Moving one node of tree inside another

        Sets ``parent_id`` and records the request in ``mptt_move_inside``.
        The actual renumbering happens on flush.

        For example see:

        * :mod:`sqlalchemy_mptt.tests.cases.move_node.test_move_inside_function`
        * :mod:`sqlalchemy_mptt.tests.cases.move_node.test_move_inside_to_the_same_parent_function`
        """  # noqa
        session = Session.object_session(self)
        self.parent_id = parent_id
        # NOTE(review): presumably read by the flush listeners in ``.events``.
        self.mptt_move_inside = parent_id
        session.add(self)

    def move_after(self, node_id):
        """ Moving one node of tree after another

        Records the request in ``mptt_move_after``. The actual renumbering
        happens on flush.

        For example see :mod:`sqlalchemy_mptt.tests.cases.move_node.test_move_after_function`
        """  # noqa
        session = Session.object_session(self)
        # Self-assignment is intentional. Setting a mapped attribute marks the
        # instance dirty even without a net change, so the flush processes it.
        # ``mptt_move_after`` is not a mapped column and would not do that.
        self.parent_id = self.parent_id
        self.mptt_move_after = node_id
        session.add(self)

    def move_before(self, node_id):
        """ Moving one node of tree before another

        The moved node adopts the parent of node ``node_id``. The request is
        recorded in ``mptt_move_before`` and the renumbering happens on flush.

        For example see:

        * :mod:`sqlalchemy_mptt.tests.cases.move_node.test_move_before_function`
        * :mod:`sqlalchemy_mptt.tests.cases.move_node.test_move_before_to_other_tree`
        * :mod:`sqlalchemy_mptt.tests.cases.move_node.test_move_before_to_top_level`
        """  # noqa
        session = Session.object_session(self)
        table = _get_tree_table(self.__mapper__)
        pk = getattr(table.c, self.get_pk_column().name)
        # ``.one()`` raises if ``node_id`` does not exist.
        node = session.query(table).filter(pk == node_id).one()
        self.parent_id = node.parent_id
        self.mptt_move_before = node_id
        session.add(self)

    def leftsibling_in_level(self):
        """ Node to the left of the current node at the same level

        Returns the row (from the tree table) with the greatest ``lft`` that
        is smaller than this node's, in the same tree and level. Returns
        ``None`` if there is no such node.

        For example see
        :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_leftsibling_in_level`
        """  # noqa
        table = _get_tree_table(self.__mapper__)
        session = Session.object_session(self)
        # Only the closest node on the left is needed. Let the database pick
        # it (ORDER BY lft DESC + LIMIT 1) instead of loading every node on
        # that side of the level and keeping the last one.
        return (
            session.query(table)
            .filter_by(level=self.level)
            .filter_by(tree_id=self.tree_id)
            .filter(table.c.lft < self.left)
            .order_by(desc(table.c.lft))
            .first()
        )

    @classmethod
    def _node_to_dict(cls, node, json, json_fields):
        """ Helper method for ``get_tree``.

        Returns ``{"id": pk, "label": repr(node)}``, updated with
        ``json_fields(node)``, when ``json`` is true. Otherwise returns
        ``{"node": node}``.
        """
        if json:
            # jqTree or jsTree format
            result = {"id": node.get_pk_value(), "label": repr(node)}
            if json_fields:
                result.update(json_fields(node))
        else:
            result = {"node": node}
        return result

    @classmethod
    def _base_query(cls, session=None):
        # ``session`` is required here despite the default. None fails.
        return session.query(cls)

    def _base_query_obj(self, session=None):
        # Like ``_base_query``, but falls back to this instance's session.
        if session is None:
            session = object_session(self)
        return self._base_query(session)

    @classmethod
    def _base_order(cls, query, order=asc):
        # Tree order: by tree, then depth, then left bound.
        return (
            query.order_by(order(cls.tree_id))
            .order_by(order(cls.level))
            .order_by(order(cls.left))
        )

    @classmethod
    def get_tree(cls, session=None, json=False, json_fields=None, query=None):
        """ This method generates a tree of the node table in dict or JSON
        format. You can customize the query with the ``query`` argument. By
        default it returns all nodes in the table.

        Args:
            session (:mod:`sqlalchemy.orm.session.Session`): SQLAlchemy session

        Kwargs:
            json (bool): if True return JSON jqTree format
            json_fields (function): append custom fields in JSON
            query (function): it takes :class:`sqlalchemy.orm.query.Query`
            object as an argument, and returns in a modified form

                ::

                    def query(nodes):
                        return nodes.filter(node.__class__.tree_id == node.tree_id)

                    node.get_tree(session=DBSession, json=True, query=query)

        Nodes at the smallest level present in the result become top-level
        entries. A node whose parent was filtered out of the result is
        omitted, together with its subtree.

        Example:

        * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_get_tree`
        * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_get_json_tree`
        * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_get_json_tree_with_custom_field`
        """  # noqa
        tree = []
        # pk value -> dict already emitted for that node, so that each child
        # can be attached to its parent's dict when it is reached.
        nodes_of_level = {}

        # handle custom query
        nodes = cls._base_query(session)
        if query:
            nodes = query(nodes)
        # Ordered by (tree_id, level, lft): within a tree, a parent (lower
        # level) is always visited before its children.
        nodes = cls._base_order(nodes).all()

        # search minimal level of nodes.
        min_level = min([node.level for node in nodes] or [None])

        for node in nodes:
            result = cls._node_to_dict(node, json, json_fields)
            if node.level == min_level:
                # top level node
                tree.append(result)
            else:
                # child node: skip it if its parent is not in the result
                parent = nodes_of_level.get(node.parent_id)
                if parent is None:
                    continue
                parent.setdefault("children", []).append(result)
            nodes_of_level[node.get_pk_value()] = result
        return tree

    def _drilldown_query(self, nodes=None):
        # Restrict ``nodes`` to this node and all of its descendants.
        table = self.__class__
        if nodes is None:
            nodes = self._base_query_obj()
        return nodes.filter(self.is_ancestor_of(table, inclusive=True))

    def drilldown_tree(self, session=None, json=False, json_fields=None):
        """ This method generates a branch of the tree, beginning with the
        current node (see ``get_tree`` for the output format).

        For example:

            node7.drilldown_tree()

        .. code::

            level           Nested sets example
            1                    1(1)22       ---------------------
                    _______________|_________|_________           |
                   |               |         |         |          |
            2    2(2)5           6(4)11      |      12(7)21       |
                   |               ^         |         ^          |
            3    3(3)4       7(5)8   9(6)10  | 13(8)16   17(10)20 |
                                             |    |          |    |
            4                                | 14(9)15   18(11)19 |
                                             |                    |
                                              ---------------------

        returns node 7 with children 8 (child 9) and 10 (child 11).

        Example in tests:

        * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_drilldown_tree`
        """
        if session is None:
            session = object_session(self)
        return self.get_tree(
            session, json=json, json_fields=json_fields, query=self._drilldown_query
        )

    def path_to_root(self, session=None, order=desc):
        """Generate path from a leaf or intermediate node to the root.

        Returns a query of this node and all of its ancestors, ordered by
        (tree_id, level, lft) using ``order`` (descending by default, i.e.
        from this node up to the root).

        For example:

            node11.path_to_root()  ->  [Node(11), Node(10), Node(7), Node(1)]

        .. code::

            level           Nested sets example

                            -----------------------------------------
            1              |                    1(1)22               |
                    ________|______|_____________________            |
                   |        |      |                     |           |
                   |         ------+---------            |           |
            2    2(2)5           6(4)11      | --     12(7)21        |
                   |               ^         |   /          \\        |
            3    3(3)4       7(5)8   9(6)10  ---/----        \\       |
                                             13(8)16 |    17(10)20   |
                                                |    |       |       |
            4                                14(9)15 |    18(11)19   |
                                                     |               |
                                                      -------------
        """
        table = self.__class__
        query = self._base_query_obj(session=session)
        query = query.filter(table.is_ancestor_of(self, inclusive=True))
        return self._base_order(query, order=order)

    def get_siblings(self, include_self=False, session=None):
        """
        https://github.com/uralbash/sqlalchemy_mptt/issues/64
        https://django-mptt.readthedocs.io/en/latest/models.html#get-siblings-include-self-false

        Creates a query containing siblings of this model
        instance. Root nodes are considered to be siblings of other root
        nodes.

        For example:

            node10.get_siblings() -> [Node(8)]

            Only one node is sibling of node10

        .. code::

            level           Nested sets example
            1                    1(1)22
                    ______________|____________________
                   |              |                    |
                   |              |                    |
            2    2(2)5          6(4)11              12(7)21
                   |              ^                  /    \\
            3    3(3)4      7(5)8   9(6)10          /      \\
                                                13(8)16  17(10)20
                                                   |        |
            4                                   14(9)15  18(11)19
        """
        table = self.__class__
        query = self._base_query_obj(session=session)
        # Compare against None explicitly: a falsy primary key such as 0 is
        # still a real parent and must not be mistaken for "root".
        if self.parent_id is not None:
            query = query.filter(table.parent_id == self.parent_id)
        else:
            query = query.filter(table.parent_id.is_(None))
        if not include_self:
            query = query.filter(self.get_pk_column() != self.get_pk_value())
        return query

    def get_children(self, session=None):
        """
        https://github.com/uralbash/sqlalchemy_mptt/issues/64
        https://github.com/django-mptt/django-mptt/blob/fd76a816e05feb5fb0fc23126d33e514460a0ead/mptt/models.py#L563

        Returns a query containing the immediate children of this
        model instance (rows whose ``parent_id`` is this node's pk).

        For example:

            node7.get_children() -> [Node(8), Node(10)]

        .. code::

            level           Nested sets example
            1                    1(1)22
                    ______________|____________________
                   |              |                    |
                   |              |                    |
            2    2(2)5          6(4)11              12(7)21
                   |              ^                  /    \\
            3    3(3)4      7(5)8   9(6)10          /      \\
                                                13(8)16  17(10)20
                                                   |        |
            4                                   14(9)15  18(11)19
        """
        table = self.__class__
        query = self._base_query_obj(session=session)
        query = query.filter(table.parent_id == self.get_pk_value())
        return query

    @classmethod
    def rebuild_tree(cls, session, tree_id):
        """ This method rebuilds a tree.

        Resets ``lft``/``rgt``/``level`` of every node of the tree to 0, then
        renumbers the nodes depth-first. Numbering starts from the node with
        NULL ``parent_id`` in that tree; ``.one()`` raises unless there is
        exactly one such node.

        Args:
            session (:mod:`sqlalchemy.orm.session.Session`): SQLAlchemy session
            tree_id (int or str): id of tree

        Example:

        * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_rebuild`
        """
        session.query(cls).filter_by(tree_id=tree_id).update(
            {cls.left: 0, cls.right: 0, cls.level: 0}
        )
        top = (
            session.query(cls)
            .filter_by(parent_id=None)
            .filter_by(tree_id=tree_id)
            .one()
        )
        top.left = left = 1
        top.right = right = 2
        top.level = level = cls.get_default_level()

        def recursive(children, left, right, level):
            level = level + 1
            for i, node in enumerate(children):
                # Right bound of the previous sibling. For i == 0 this wraps
                # around to the last sibling, which has not been renumbered yet.
                # NOTE(review): relies on that sibling's ``right`` still being
                # 0 after the reset above, so the branch below is skipped for
                # the first child.
                same_level_right = children[i - 1].right
                left = left + 1
                if i > 0:
                    left = left + 1
                if same_level_right:
                    # Place this node right after the previous sibling.
                    left = same_level_right + 1
                right = left + 1
                node.left = left
                node.right = right
                # Widen each ancestor so that it encloses the node just placed:
                # parent gets right + 1, grandparent right + 2, and so on.
                parent = node.parent
                j = 0
                while parent:
                    parent.right = right + 1 + j
                    parent = parent.parent
                    j += 1
                node.level = level
                recursive(node.children, left, right, level)

        recursive(top.children, left, right, level)

    @classmethod
    def rebuild(cls, session, tree_id=None):
        """ This function rebuilds trees.

        Calls ``rebuild_tree`` for every root node (``parent_id`` NULL), or
        only for the root of ``tree_id`` when it is given.

        Args:
            session (:mod:`sqlalchemy.orm.session.Session`): SQLAlchemy session

        Kwargs:
            tree_id (int or str): id of tree, default None

        Example:

        * :mod:`sqlalchemy_mptt.tests.TestTree.test_rebuild`
        """
        trees = session.query(cls).filter_by(parent_id=None)
        # ``is not None`` so that a falsy id such as 0 rebuilds only that tree
        # rather than every tree.
        if tree_id is not None:
            trees = trees.filter_by(tree_id=tree_id)
        for tree in trees:
            cls.rebuild_tree(session, tree.tree_id)