import unittest
import sqlalchemy as sa
from sqlalchemy_mptt.mixins import BaseNestedSets
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
from sqlalchemy_mptt.tests import (DatabaseSetupMixin, TreeTestingMixin,
failures_expected_on)
# Declarative base shared by GenericTree / SpecializedTree and by the test
# cases that exercise them (TestTree, TestGenericTree, TestSpecializedTree).
Base = compat_layer.declarative_base()
class GenericTree(Base, BaseNestedSets):
    """Polymorphic root of a nested-set tree stored in the ``generic`` table.

    The primary-key attribute is ``ppk`` (database column ``idd``);
    ``sqlalchemy_mptt_pk_name`` names that attribute for the BaseNestedSets
    mixin.  ``type`` is the polymorphic discriminator and plain GenericTree
    rows carry identity ``0``.
    """

    __tablename__ = "generic"
    sqlalchemy_mptt_pk_name = 'ppk'

    # Column definition order is kept as-is: it determines the table's
    # column order in the generated DDL.
    ppk = sa.Column('idd', sa.Integer, primary_key=True)
    # Within this class body ``type`` shadows the builtin; the mapper args
    # below rely on that to reference this discriminator Column.
    type = sa.Column(sa.Integer, default=0)
    visible = sa.Column(sa.Boolean)

    __mapper_args__ = {
        'polymorphic_on': type,
        'polymorphic_identity': 0,
    }

    def __repr__(self):
        return "<Node ({})>".format(self.ppk)


class SpecializedTree(GenericTree):
    """Joined-table subclass of GenericTree stored in ``specialized``.

    Its ``idd`` column is both the primary key and a foreign key onto
    ``GenericTree.ppk``; rows of this class carry polymorphic identity ``1``.
    """

    __tablename__ = "specialized"
    # NOTE(review): explicitly empties the table arguments for this subclass
    # table. Presumably this drops table-level definitions contributed by
    # BaseNestedSets, which belong only on the ``generic`` table holding the
    # tree columns — confirm against the mixin.
    __table_args__ = ()

    ppk = sa.Column('idd', sa.Integer, sa.ForeignKey(GenericTree.ppk),
                    primary_key=True)

    __mapper_args__ = {'polymorphic_identity': 1}


class TestTree(DatabaseSetupMixin, unittest.TestCase):
    """Create/delete round-trips on a tree mixing GenericTree and
    SpecializedTree nodes."""

    base = Base

    def test_create_generic(self):
        # A lone GenericTree node becomes the root of tree 1.
        self.session.add(GenericTree(ppk=1))
        self.session.commit()
        tree = compat_layer.get(self.session, GenericTree, 1)
        self.assertEqual(tree.ppk, 1)
        self.assertEqual(tree.tree_id, 1)

    def test_create_spec(self):
        # A lone SpecializedTree node becomes the root of tree 1.
        self.session.add(SpecializedTree(ppk=1))
        self.session.commit()
        tree = compat_layer.get(self.session, SpecializedTree, 1)
        self.assertEqual(tree.ppk, 1)
        self.assertEqual(tree.tree_id, 1)

    def test_create_delete(self):
        # Tree layout (class in brackets):
        #   1 [Specialized]
        #   +-- 2 [Specialized]
        #   +-- 3 [Generic]
        #       +-- 4 [Generic]
        #       +-- 5 [Specialized]
        parent = SpecializedTree(ppk=1)
        child1 = SpecializedTree(ppk=2, parent=parent)
        child2 = GenericTree(ppk=3, parent=parent)
        GenericTree(ppk=4, parent=child2)
        SpecializedTree(ppk=5, parent=child2)
        self.session.add(parent)
        self.session.commit()

        tree = compat_layer.get(self.session, SpecializedTree, 1)
        self.assertEqual(tree.ppk, 1)
        self.assertEqual(tree.tree_id, 1)

        # Existence checks go through GenericTree: lookups on the
        # polymorphic root find a row whatever its subclass. A lookup
        # through SpecializedTree can never return the GenericTree rows
        # (ppk 3 and 4), so asserting None there would pass even if the
        # delete had not happened.
        self.session.delete(child1)
        self.session.commit()
        self.assertIsNone(compat_layer.get(self.session, GenericTree, 2))

        # Deleting node 3 must also remove its whole subtree (4 and 5).
        self.session.delete(child2)
        self.session.commit()
        for ppk in (3, 4, 5):
            with self.subTest(ppk=ppk):
                self.assertIsNone(
                    compat_layer.get(self.session, GenericTree, ppk))

        # Only the deleted subtrees are gone; the root is still there.
        self.assertIsNotNone(compat_layer.get(self.session, GenericTree, 1))


class TestGenericTree(TreeTestingMixin, unittest.TestCase):
    """Reuse the TreeTestingMixin test cases with GenericTree as the model."""

    model = GenericTree
    base = Base


class TestSpecializedTree(TreeTestingMixin, unittest.TestCase):
    """Reuse the TreeTestingMixin test cases with SpecializedTree as model."""

    model = SpecializedTree
    base = Base

    @unittest.expectedFailure
    def test_rebuild(self):
        # Known failure: per the original author, the inherited rebuild check
        # always fails on specialized (subclass) models.
        # NOTE(review): the underlying cause is not visible in this module.
        super().test_rebuild()


# Separate declarative base for the second scenario, in which BaseNestedSets
# is mixed into a subclass (InheritanceTree) rather than into the root model.
Base2 = compat_layer.declarative_base()
class BaseInheritance(Base2):
    """Plain (non-tree) polymorphic root mapped to ``base_inheritance``.

    ``ppk`` maps to the ``idd`` column; ``type`` is the discriminator and
    rows of this class carry polymorphic identity ``0``.
    """

    __tablename__ = "base_inheritance"

    # Column definition order is kept as-is (it fixes the DDL column order).
    ppk = sa.Column('idd', sa.Integer, primary_key=True)
    # ``type`` shadows the builtin inside this class body; the mapper args
    # below rely on that to reference the discriminator Column.
    type = sa.Column(sa.Integer, default=0)
    visible = sa.Column(sa.Boolean)

    __mapper_args__ = {
        'polymorphic_on': type,
        'polymorphic_identity': 0,
    }

    def __repr__(self):
        return "<Node ({})>".format(self.ppk)


class InheritanceTree(BaseInheritance, BaseNestedSets):
    """Joined-table subclass of BaseInheritance that adds the MPTT mixin.

    Unlike SpecializedTree, BaseNestedSets is applied to the subclass here
    rather than to the polymorphic root. Rows carry identity ``1``.
    """

    # NOTE(review): "inheriance" looks like a typo for "inheritance"; kept
    # unchanged because renaming it changes the generated table name.
    __tablename__ = "inheriance_tree"
    sqlalchemy_mptt_pk_name = 'ppk'

    ppk = sa.Column(
        'idd',
        sa.Integer,
        sa.ForeignKey(BaseInheritance.ppk),
        primary_key=True,
    )

    __mapper_args__ = {'polymorphic_identity': 1}


class TestInheritanceTree(TreeTestingMixin, unittest.TestCase):
    """Reuse the TreeTestingMixin test cases with InheritanceTree as model."""

    model = InheritanceTree
    base = Base2

    @failures_expected_on(sqlalchemy_versions=['1.0', '1.1', '1.2', '1.3'])
    def test_rebuild(self):
        # Inherited rebuild check, flagged via failures_expected_on as an
        # expected failure on the listed SQLAlchemy versions.
        super().test_rebuild()